Skip to content

Commit dc715fe

Browse files
committed
Adding implementation of python based catalog and schema providers
1 parent e408c5a commit dc715fe

File tree

10 files changed

+358
-112
lines changed

10 files changed

+358
-112
lines changed

examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import annotations
1919

2020
import pyarrow as pa
21-
2221
from datafusion import SessionContext
2322
from datafusion_ffi_example import MyCatalogProvider
2423

examples/datafusion-ffi-example/src/catalog_provider.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::{any::Any, fmt::Debug, sync::Arc};
1918
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
19+
use std::{any::Any, fmt::Debug, sync::Arc};
2020

2121
use arrow::datatypes::Schema;
2222
use async_trait::async_trait;
2323
use datafusion::{
2424
catalog::{
25-
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
26-
TableProvider,
25+
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider,
2726
},
2827
common::exec_err,
2928
datasource::MemTable,
@@ -46,12 +45,12 @@ pub fn my_table() -> Arc<dyn TableProvider + 'static> {
4645
("units", Int32, vec![10, 20, 30]),
4746
("price", Float64, vec![1.0, 2.0, 5.0])
4847
)
49-
.unwrap(),
48+
.unwrap(),
5049
record_batch!(
5150
("units", Int32, vec![5, 7]),
5251
("price", Float64, vec![1.5, 2.5])
5352
)
54-
.unwrap(),
53+
.unwrap(),
5554
];
5655

5756
Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
@@ -68,9 +67,7 @@ impl Default for FixedSchemaProvider {
6867

6968
let table = my_table();
7069

71-
let _ = inner
72-
.register_table("my_table".to_string(), table)
73-
.unwrap();
70+
let _ = inner.register_table("my_table".to_string(), table).unwrap();
7471

7572
Self { inner }
7673
}
@@ -86,10 +83,7 @@ impl SchemaProvider for FixedSchemaProvider {
8683
self.inner.table_names()
8784
}
8885

89-
async fn table(
90-
&self,
91-
name: &str,
92-
) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
86+
async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
9387
self.inner.table(name).await
9488
}
9589

@@ -110,10 +104,13 @@ impl SchemaProvider for FixedSchemaProvider {
110104
}
111105
}
112106

113-
114107
/// This catalog provider is intended only for unit tests. It prepopulates with one
115108
/// schema and only allows for schemas named after four types of fruit.
116-
#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)]
109+
#[pyclass(
110+
name = "MyCatalogProvider",
111+
module = "datafusion_ffi_example",
112+
subclass
113+
)]
117114
#[derive(Debug)]
118115
pub(crate) struct MyCatalogProvider {
119116
inner: MemoryCatalogProvider,
@@ -174,8 +171,9 @@ impl MyCatalogProvider {
174171
py: Python<'py>,
175172
) -> PyResult<Bound<'py, PyCapsule>> {
176173
let name = cr"datafusion_catalog_provider".into();
177-
let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);
174+
let catalog_provider =
175+
FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);
178176

179177
PyCapsule::new(py, catalog_provider, Some(name))
180178
}
181-
}
179+
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::catalog_provider::MyCatalogProvider;
1819
use crate::table_function::MyTableFunction;
1920
use crate::table_provider::MyTableProvider;
20-
use crate::catalog_provider::MyCatalogProvider;
2121
use pyo3::prelude::*;
2222

23+
pub(crate) mod catalog_provider;
2324
pub(crate) mod table_function;
2425
pub(crate) mod table_provider;
25-
pub(crate) mod catalog_provider;
2626

2727
#[pymodule]
2828
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from datafusion.col import col, column
3030

31-
from . import functions, object_store, substrait, unparser
31+
from . import catalog, functions, object_store, substrait, unparser
3232

3333
# The following imports are okay to remain as opaque to the user.
3434
from ._internal import Config
@@ -87,6 +87,7 @@
8787
"TableFunction",
8888
"WindowFrame",
8989
"WindowUDF",
90+
"catalog",
9091
"col",
9192
"column",
9293
"common",

python/datafusion/catalog.py

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,43 +26,99 @@
2626
if TYPE_CHECKING:
2727
import pyarrow as pa
2828

29+
try:
30+
from warnings import deprecated # Python 3.13+
31+
except ImportError:
32+
from typing_extensions import deprecated # Python < 3.13
33+
34+
35+
__all__ = [
36+
"Catalog",
37+
"Schema",
38+
"Table",
39+
]
40+
2941

3042
class Catalog:
3143
"""DataFusion data catalog."""
3244

33-
def __init__(self, catalog: df_internal.Catalog) -> None:
45+
def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None:
3446
"""This constructor is not typically called by the end user."""
3547
self.catalog = catalog
3648

37-
def names(self) -> list[str]:
38-
"""Returns the list of databases in this catalog."""
39-
return self.catalog.names()
49+
def __repr__(self) -> str:
50+
"""User friendly printable display of this catalog."""
51+
return self.catalog.__repr__()
52+
53+
def names(self) -> set[str]:
54+
"""This is an alias for `schema_names`."""
55+
return self.schema_names()
56+
57+
def schema_names(self) -> set[str]:
58+
"""Returns the set of schema names in this catalog."""
59+
return self.catalog.schema_names()
60+
61+
def schema(self, name: str = "public") -> Schema:
62+
"""Returns the schema with the given ``name`` from this catalog."""
63+
schema = self.catalog.schema(name)
64+
65+
return (
66+
Schema(schema)
67+
if isinstance(schema, df_internal.catalog.RawSchema)
68+
else schema
69+
)
4070

41-
def database(self, name: str = "public") -> Database:
71+
@deprecated("Use `schema` instead.")
72+
def database(self, name: str = "public") -> Schema:
4273
"""Returns the database with the given ``name`` from this catalog."""
43-
return Database(self.catalog.database(name))
74+
return self.schema(name)
75+
76+
def register_schema(self, name, schema) -> Schema | None:
77+
"""Register a schema with this catalog."""
78+
return self.catalog.register_schema(name, schema)
4479

80+
def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None:
81+
"""Deregister a schema from this catalog."""
82+
return self.catalog.deregister_schema(name, cascade)
4583

46-
class Database:
47-
"""DataFusion Database."""
4884

49-
def __init__(self, db: df_internal.Database) -> None:
85+
class Schema:
86+
"""DataFusion Schema."""
87+
88+
def __init__(self, schema: df_internal.catalog.RawSchema) -> None:
5089
"""This constructor is not typically called by the end user."""
51-
self.db = db
90+
self._raw_schema = schema
5291

5392
def names(self) -> set[str]:
54-
"""Returns the list of all tables in this database."""
55-
return self.db.names()
93+
"""This is an alias for `table_names`."""
94+
return self.table_names()
95+
96+
def table_names(self) -> set[str]:
97+
"""Returns the set of all table names in this schema."""
return self._raw_schema.table_names
5699

57100
def table(self, name: str) -> Table:
58-
"""Return the table with the given ``name`` from this database."""
59-
return Table(self.db.table(name))
101+
"""Return the table with the given ``name`` from this schema."""
102+
return Table(self._raw_schema.table(name))
103+
104+
def register_table(self, name, table) -> None:
105+
"""Register a table provider in this schema."""
106+
return self._raw_schema.register_table(name, table)
107+
108+
def deregister_table(self, name: str) -> None:
109+
"""Deregister a table provider from this schema."""
110+
return self._raw_schema.deregister_table(name)
111+
112+
113+
@deprecated("Use `Schema` instead.")
114+
class Database(Schema):
115+
"""See `Schema`."""
60116

61117

62118
class Table:
63119
"""DataFusion table."""
64120

65-
def __init__(self, table: df_internal.Table) -> None:
121+
def __init__(self, table: df_internal.catalog.RawTable) -> None:
66122
"""This constructor is not typically called by the end user."""
67123
self.table = table
68124

python/tests/test_catalog.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import datafusion as dfn
1819
import pyarrow as pa
20+
import pyarrow.dataset as ds
1921
import pytest
22+
from datafusion import SessionContext, Table
2023

2124

2225
# Note we take in `database` as a variable even though we don't use
@@ -27,7 +30,7 @@ def test_basic(ctx, database):
2730
ctx.catalog("non-existent")
2831

2932
default = ctx.catalog()
30-
assert default.names() == ["public"]
33+
assert default.names() == {"public"}
3134

3235
for db in [default.database("public"), default.database()]:
3336
assert db.names() == {"csv1", "csv", "csv2"}
@@ -41,3 +44,100 @@ def test_basic(ctx, database):
4144
pa.field("float", pa.float64(), nullable=True),
4245
]
4346
)
47+
48+
49+
class CustomTableProvider:
50+
def __init__(self):
51+
pass
52+
53+
54+
def create_dataset() -> pa.dataset.Dataset:
55+
batch = pa.RecordBatch.from_arrays(
56+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
57+
names=["a", "b"],
58+
)
59+
return ds.dataset([batch])
60+
61+
62+
class CustomSchemaProvider:
63+
def __init__(self):
64+
self.tables = {"table1": create_dataset()}
65+
66+
def table_names(self) -> set[str]:
67+
return set(self.tables.keys())
68+
69+
def register_table(self, name: str, table: Table):
70+
self.tables[name] = table
71+
72+
def deregister_table(self, name, cascade: bool = True):
73+
del self.tables[name]
74+
75+
76+
class CustomCatalogProvider:
77+
def __init__(self):
78+
self.schemas = {"my_schema": CustomSchemaProvider()}
79+
80+
def schema_names(self) -> set[str]:
81+
return set(self.schemas.keys())
82+
83+
def schema(self, name: str):
84+
return self.schemas[name]
85+
86+
def register_schema(self, name: str, schema: dfn.catalog.Schema):
87+
self.schemas[name] = schema
88+
89+
def deregister_schema(self, name, cascade: bool):
90+
del self.schemas[name]
91+
92+
93+
def test_python_catalog_provider(ctx: SessionContext):
94+
ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())
95+
96+
# Check the default catalog provider
97+
assert ctx.catalog("datafusion").names() == {"public"}
98+
99+
my_catalog = ctx.catalog("my_catalog")
100+
assert my_catalog.names() == {"my_schema"}
101+
102+
my_catalog.register_schema("second_schema", CustomSchemaProvider())
103+
assert my_catalog.schema_names() == {"my_schema", "second_schema"}
104+
105+
my_catalog.deregister_schema("my_schema")
106+
assert my_catalog.schema_names() == {"second_schema"}
107+
108+
109+
def test_python_schema_provider(ctx: SessionContext):
110+
catalog = ctx.catalog()
111+
112+
catalog.deregister_schema("public")
113+
114+
catalog.register_schema("test_schema1", CustomSchemaProvider())
115+
assert catalog.names() == {"test_schema1"}
116+
117+
catalog.register_schema("test_schema2", CustomSchemaProvider())
118+
catalog.deregister_schema("test_schema1")
119+
assert catalog.names() == {"test_schema2"}
120+
121+
122+
def test_python_table_provider(ctx: SessionContext):
123+
catalog = ctx.catalog()
124+
125+
catalog.register_schema("custom_schema", CustomSchemaProvider())
126+
schema = catalog.schema("custom_schema")
127+
128+
assert schema.table_names() == {"table1"}
129+
130+
schema.deregister_table("table1")
131+
schema.register_table("table2", create_dataset())
132+
assert schema.table_names() == {"table2"}
133+
134+
# Use the default schema instead of our custom schema
135+
136+
schema = catalog.schema()
137+
138+
schema.register_table("table3", create_dataset())
139+
assert schema.table_names() == {"table3"}
140+
141+
schema.deregister_table("table3")
142+
schema.register_table("table4", create_dataset())
143+
assert schema.table_names() == {"table4"}

0 commit comments

Comments
 (0)