Skip to content

Commit e408c5a

Browse files
committed
Fleshing out schema and catalog providers
1 parent c6d406a commit e408c5a

File tree

6 files changed

+146
-8
lines changed

6 files changed

+146
-8
lines changed

Cargo.lock

Lines changed: 23 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ substrait = ["dep:datafusion-substrait"]
3737
tokio = { version = "1.44", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3838
pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] }
3939
pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]}
40+
pyo3-log = "0.12.4"
4041
arrow = { version = "55.0.0", features = ["pyarrow"] }
4142
datafusion = { version = "47.0.0", features = ["avro", "unicode_expressions"] }
4243
datafusion-substrait = { version = "47.0.0", optional = true }
@@ -49,6 +50,7 @@ async-trait = "0.1.88"
4950
futures = "0.3"
5051
object_store = { version = "0.12.0", features = ["aws", "gcp", "azure", "http"] }
5152
url = "2"
53+
log = "0.4.27"
5254

5355
[build-dependencies]
5456
prost-types = "0.13.1" # keep in line with `datafusion-substrait`

src/catalog.rs

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion::{
2525
catalog::{CatalogProvider, SchemaProvider},
2626
datasource::{TableProvider, TableType},
2727
};
28+
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
2829
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2930
use pyo3::exceptions::PyKeyError;
3031
use pyo3::prelude::*;
@@ -48,8 +49,8 @@ pub struct PyTable {
4849
pub table: Arc<dyn TableProvider>,
4950
}
5051

51-
impl PyCatalog {
52-
pub fn new(catalog: Arc<dyn CatalogProvider>) -> Self {
52+
impl From<Arc<dyn CatalogProvider>> for PyCatalog {
53+
fn from(catalog: Arc<dyn CatalogProvider>) -> Self {
5354
Self { catalog }
5455
}
5556
}
@@ -72,6 +73,13 @@ impl PyTable {
7273

7374
#[pymethods]
7475
impl PyCatalog {
76+
#[new]
77+
fn new(catalog: PyObject) -> Self {
78+
let catalog_provider =
79+
Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>;
80+
catalog_provider.into()
81+
}
82+
7583
fn names(&self) -> Vec<String> {
7684
self.catalog.schema_names()
7785
}
@@ -286,3 +294,109 @@ impl SchemaProvider for RustWrappedPySchemaProvider {
286294
})
287295
}
288296
}
297+
298+
/// Adapter that holds a Python object so it can be used as a `CatalogProvider`.
#[derive(Debug)]
299+
struct RustWrappedPyCatalogProvider {
300+
/// Python object that the trait methods forward their calls to.
catalog_provider: PyObject,
301+
}
302+
303+
impl RustWrappedPyCatalogProvider {
304+
fn new(catalog_provider: PyObject) -> Self {
305+
Self { catalog_provider }
306+
}
307+
308+
fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> {
309+
Python::with_gil(|py| {
310+
let provider = self.catalog_provider.bind(py);
311+
312+
let py_schema = provider.call_method1("schema", (name,))?;
313+
if py_schema.is_none() {
314+
return Ok(None);
315+
}
316+
317+
if py_schema.hasattr("__datafusion_schema_provider__")? {
318+
let capsule = provider
319+
.getattr("__datafusion_schema_provider__")?
320+
.call0()?;
321+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
322+
validate_pycapsule(capsule, "datafusion_schema_provider")?;
323+
324+
let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
325+
let provider: ForeignSchemaProvider = provider.into();
326+
327+
Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>))
328+
} else {
329+
let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
330+
331+
Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
332+
}
333+
})
334+
}
335+
}
336+
337+
#[async_trait]
338+
impl CatalogProvider for RustWrappedPyCatalogProvider {
339+
fn as_any(&self) -> &dyn Any {
340+
self
341+
}
342+
343+
fn schema_names(&self) -> Vec<String> {
344+
Python::with_gil(|py| {
345+
let provider = self.catalog_provider.bind(py);
346+
provider
347+
.getattr("schema_names")
348+
.and_then(|names| names.extract::<Vec<String>>())
349+
.unwrap_or_default()
350+
})
351+
}
352+
353+
fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
354+
self.schema_inner(name).unwrap_or_else(|err| {
355+
log::error!("CatalogProvider schema returned error: {err}");
356+
None
357+
})
358+
}
359+
360+
fn register_schema(
361+
&self,
362+
name: &str,
363+
schema: Arc<dyn SchemaProvider>,
364+
) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
365+
let py_schema: PyDatabase = schema.into();
366+
Python::with_gil(|py| {
367+
let provider = self.catalog_provider.bind(py);
368+
let schema = provider
369+
.call_method1("register_schema", (name, py_schema))
370+
.map_err(to_datafusion_err)?;
371+
if schema.is_none() {
372+
return Ok(None);
373+
}
374+
375+
let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
376+
as Arc<dyn SchemaProvider>;
377+
378+
Ok(Some(schema))
379+
})
380+
}
381+
382+
fn deregister_schema(
383+
&self,
384+
name: &str,
385+
cascade: bool,
386+
) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
387+
Python::with_gil(|py| {
388+
let provider = self.catalog_provider.bind(py);
389+
let schema = provider
390+
.call_method1("deregister_schema", (name, cascade))
391+
.map_err(to_datafusion_err)?;
392+
if schema.is_none() {
393+
return Ok(None);
394+
}
395+
396+
let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into()))
397+
as Arc<dyn SchemaProvider>;
398+
399+
Ok(Some(schema))
400+
})
401+
}
402+
}

src/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ impl PySessionContext {
842842
#[pyo3(signature = (name="datafusion"))]
843843
pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
844844
match self.ctx.catalog(name) {
845-
Some(catalog) => Ok(PyCatalog::new(catalog)),
845+
Some(catalog) => Ok(PyCatalog::from(catalog)),
846846
None => Err(PyKeyError::new_err(format!(
847847
"Catalog with name {} doesn't exist.",
848848
&name,

src/functions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
937937
m.add_wrapped(wrap_pyfunction!(left))?;
938938
m.add_wrapped(wrap_pyfunction!(length))?;
939939
m.add_wrapped(wrap_pyfunction!(ln))?;
940-
m.add_wrapped(wrap_pyfunction!(log))?;
940+
m.add_wrapped(wrap_pyfunction!(self::log))?;
941941
m.add_wrapped(wrap_pyfunction!(log10))?;
942942
m.add_wrapped(wrap_pyfunction!(log2))?;
943943
m.add_wrapped(wrap_pyfunction!(lower))?;

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
7777
/// datafusion directory.
7878
#[pymodule]
7979
fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
80+
// Initialize logging
81+
pyo3_log::init();
82+
8083
// Register the python classes
8184
m.add_class::<catalog::PyCatalog>()?;
8285
m.add_class::<catalog::PyDatabase>()?;

0 commit comments

Comments
 (0)