Skip to content

Commit 2584075

Browse files
committed
Adding implementation of python based catalog and schema providers
1 parent cadae67 commit 2584075

File tree

9 files changed

+359
-99
lines changed

9 files changed

+359
-99
lines changed

examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import annotations
1919

2020
import pyarrow as pa
21-
2221
from datafusion import SessionContext
2322
from datafusion_ffi_example import MyCatalogProvider
2423

examples/datafusion-ffi-example/src/catalog_provider.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::{any::Any, fmt::Debug, sync::Arc};
1918
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
19+
use std::{any::Any, fmt::Debug, sync::Arc};
2020

2121
use arrow::datatypes::Schema;
2222
use async_trait::async_trait;
2323
use datafusion::{
2424
catalog::{
25-
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
26-
TableProvider,
25+
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider,
2726
},
2827
common::exec_err,
2928
datasource::MemTable,
@@ -46,12 +45,12 @@ pub fn my_table() -> Arc<dyn TableProvider + 'static> {
4645
("units", Int32, vec![10, 20, 30]),
4746
("price", Float64, vec![1.0, 2.0, 5.0])
4847
)
49-
.unwrap(),
48+
.unwrap(),
5049
record_batch!(
5150
("units", Int32, vec![5, 7]),
5251
("price", Float64, vec![1.5, 2.5])
5352
)
54-
.unwrap(),
53+
.unwrap(),
5554
];
5655

5756
Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
@@ -68,9 +67,7 @@ impl Default for FixedSchemaProvider {
6867

6968
let table = my_table();
7069

71-
let _ = inner
72-
.register_table("my_table".to_string(), table)
73-
.unwrap();
70+
let _ = inner.register_table("my_table".to_string(), table).unwrap();
7471

7572
Self { inner }
7673
}
@@ -86,10 +83,7 @@ impl SchemaProvider for FixedSchemaProvider {
8683
self.inner.table_names()
8784
}
8885

89-
async fn table(
90-
&self,
91-
name: &str,
92-
) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
86+
async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
9387
self.inner.table(name).await
9488
}
9589

@@ -110,10 +104,13 @@ impl SchemaProvider for FixedSchemaProvider {
110104
}
111105
}
112106

113-
114107
/// This catalog provider is intended only for unit tests. It prepopulates with one
115108
/// schema and only allows for schemas named after four types of fruit.
116-
#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)]
109+
#[pyclass(
110+
name = "MyCatalogProvider",
111+
module = "datafusion_ffi_example",
112+
subclass
113+
)]
117114
#[derive(Debug)]
118115
pub(crate) struct MyCatalogProvider {
119116
inner: MemoryCatalogProvider,
@@ -174,8 +171,9 @@ impl MyCatalogProvider {
174171
py: Python<'py>,
175172
) -> PyResult<Bound<'py, PyCapsule>> {
176173
let name = cr"datafusion_catalog_provider".into();
177-
let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);
174+
let catalog_provider =
175+
FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);
178176

179177
PyCapsule::new(py, catalog_provider, Some(name))
180178
}
181-
}
179+
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::catalog_provider::MyCatalogProvider;
1819
use crate::table_function::MyTableFunction;
1920
use crate::table_provider::MyTableProvider;
20-
use crate::catalog_provider::MyCatalogProvider;
2121
use pyo3::prelude::*;
2222

23+
pub(crate) mod catalog_provider;
2324
pub(crate) mod table_function;
2425
pub(crate) mod table_provider;
25-
pub(crate) mod catalog_provider;
2626

2727
#[pymodule]
2828
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from datafusion.col import col, column
3434

35-
from . import functions, object_store, substrait, unparser
35+
from . import catalog, functions, object_store, substrait, unparser
3636

3737
# The following imports are okay to remain as opaque to the user.
3838
from ._internal import Config
@@ -91,6 +91,7 @@
9191
"TableFunction",
9292
"WindowFrame",
9393
"WindowUDF",
94+
"catalog",
9495
"col",
9596
"column",
9697
"common",

python/datafusion/catalog.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,23 @@
2626
if TYPE_CHECKING:
2727
import pyarrow as pa
2828

29+
try:
30+
from warnings import deprecated # Python 3.13+
31+
except ImportError:
32+
from typing_extensions import deprecated # Python 3.12
33+
34+
35+
__all__ = [
36+
"Catalog",
37+
"Schema",
38+
"Table",
39+
]
40+
2941

3042
class Catalog:
3143
"""DataFusion data catalog."""
3244

33-
def __init__(self, catalog: df_internal.Catalog) -> None:
45+
def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None:
3446
"""This constructor is not typically called by the end user."""
3547
self.catalog = catalog
3648

@@ -59,18 +71,74 @@ def __repr__(self) -> str:
5971
return self.db.__repr__()
6072

6173
def names(self) -> set[str]:
62-
"""Returns the list of all tables in this database."""
63-
return self.db.names()
74+
"""This is an alias for `schema_names`."""
75+
return self.schema_names()
76+
77+
def schema_names(self) -> set[str]:
78+
"""Returns the list of schemas in this catalog."""
79+
return self.catalog.schema_names()
80+
81+
def schema(self, name: str = "public") -> Schema:
82+
"""Returns the database with the given ``name`` from this catalog."""
83+
schema = self.catalog.schema(name)
84+
85+
return (
86+
Schema(schema)
87+
if isinstance(schema, df_internal.catalog.RawSchema)
88+
else schema
89+
)
90+
91+
@deprecated("Use `schema` instead.")
92+
def database(self, name: str = "public") -> Schema:
93+
"""Returns the database with the given ``name`` from this catalog."""
94+
return self.schema(name)
95+
96+
def register_schema(self, name, schema) -> Schema | None:
97+
"""Register a schema with this catalog."""
98+
return self.catalog.register_schema(name, schema)
99+
100+
def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None:
101+
"""Deregister a schema from this catalog."""
102+
return self.catalog.deregister_schema(name, cascade)
103+
104+
105+
class Schema:
106+
"""DataFusion Schema."""
107+
108+
def __init__(self, schema: df_internal.catalog.RawSchema) -> None:
109+
"""This constructor is not typically called by the end user."""
110+
self._raw_schema = schema
111+
112+
def names(self) -> set[str]:
113+
"""This is an alias for `table_names`."""
114+
return self.table_names()
115+
116+
def table_names(self) -> set[str]:
117+
"""Returns the list of all tables in this schema."""
118+
return self._raw_schema.table_names
64119

65120
def table(self, name: str) -> Table:
66-
"""Return the table with the given ``name`` from this database."""
67-
return Table(self.db.table(name))
121+
"""Return the table with the given ``name`` from this schema."""
122+
return Table(self._raw_schema.table(name))
123+
124+
def register_table(self, name, table) -> None:
125+
"""Register a table provider in this schema."""
126+
return self._raw_schema.register_table(name, table)
127+
128+
def deregister_table(self, name: str) -> None:
129+
"""Deregister a table provider from this schema."""
130+
return self._raw_schema.deregister_table(name)
131+
132+
133+
@deprecated("Use `Schema` instead.")
134+
class Database(Schema):
135+
"""See `Schema`."""
68136

69137

70138
class Table:
71139
"""DataFusion table."""
72140

73-
def __init__(self, table: df_internal.Table) -> None:
141+
def __init__(self, table: df_internal.catalog.RawTable) -> None:
74142
"""This constructor is not typically called by the end user."""
75143
self.table = table
76144

python/tests/test_catalog.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import datafusion as dfn
1819
import pyarrow as pa
20+
import pyarrow.dataset as ds
1921
import pytest
22+
from datafusion import SessionContext, Table
2023

2124

2225
# Note we take in `database` as a variable even though we don't use
@@ -27,7 +30,7 @@ def test_basic(ctx, database):
2730
ctx.catalog("non-existent")
2831

2932
default = ctx.catalog()
30-
assert default.names() == ["public"]
33+
assert default.names() == {"public"}
3134

3235
for db in [default.database("public"), default.database()]:
3336
assert db.names() == {"csv1", "csv", "csv2"}
@@ -41,3 +44,100 @@ def test_basic(ctx, database):
4144
pa.field("float", pa.float64(), nullable=True),
4245
]
4346
)
47+
48+
49+
class CustomTableProvider:
50+
def __init__(self):
51+
pass
52+
53+
54+
def create_dataset() -> pa.dataset.Dataset:
55+
batch = pa.RecordBatch.from_arrays(
56+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
57+
names=["a", "b"],
58+
)
59+
return ds.dataset([batch])
60+
61+
62+
class CustomSchemaProvider:
63+
def __init__(self):
64+
self.tables = {"table1": create_dataset()}
65+
66+
def table_names(self) -> set[str]:
67+
return set(self.tables.keys())
68+
69+
def register_table(self, name: str, table: Table):
70+
self.tables[name] = table
71+
72+
def deregister_table(self, name, cascade: bool = True):
73+
del self.tables[name]
74+
75+
76+
class CustomCatalogProvider:
77+
def __init__(self):
78+
self.schemas = {"my_schema": CustomSchemaProvider()}
79+
80+
def schema_names(self) -> set[str]:
81+
return set(self.schemas.keys())
82+
83+
def schema(self, name: str):
84+
return self.schemas[name]
85+
86+
def register_schema(self, name: str, schema: dfn.catalog.Schema):
87+
self.schemas[name] = schema
88+
89+
def deregister_schema(self, name, cascade: bool):
90+
del self.schemas[name]
91+
92+
93+
def test_python_catalog_provider(ctx: SessionContext):
94+
ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())
95+
96+
# Check the default catalog provider
97+
assert ctx.catalog("datafusion").names() == {"public"}
98+
99+
my_catalog = ctx.catalog("my_catalog")
100+
assert my_catalog.names() == {"my_schema"}
101+
102+
my_catalog.register_schema("second_schema", CustomSchemaProvider())
103+
assert my_catalog.schema_names() == {"my_schema", "second_schema"}
104+
105+
my_catalog.deregister_schema("my_schema")
106+
assert my_catalog.schema_names() == {"second_schema"}
107+
108+
109+
def test_python_schema_provider(ctx: SessionContext):
110+
catalog = ctx.catalog()
111+
112+
catalog.deregister_schema("public")
113+
114+
catalog.register_schema("test_schema1", CustomSchemaProvider())
115+
assert catalog.names() == {"test_schema1"}
116+
117+
catalog.register_schema("test_schema2", CustomSchemaProvider())
118+
catalog.deregister_schema("test_schema1")
119+
assert catalog.names() == {"test_schema2"}
120+
121+
122+
def test_python_table_provider(ctx: SessionContext):
123+
catalog = ctx.catalog()
124+
125+
catalog.register_schema("custom_schema", CustomSchemaProvider())
126+
schema = catalog.schema("custom_schema")
127+
128+
assert schema.table_names() == {"table1"}
129+
130+
schema.deregister_table("table1")
131+
schema.register_table("table2", create_dataset())
132+
assert schema.table_names() == {"table2"}
133+
134+
# Use the default schema instead of our custom schema
135+
136+
schema = catalog.schema()
137+
138+
schema.register_table("table3", create_dataset())
139+
assert schema.table_names() == {"table3"}
140+
141+
schema.deregister_table("table3")
142+
schema.register_table("table4", create_dataset())
143+
assert schema.table_names() == {"table4"}

0 commit comments

Comments
 (0)