From 6f68f9fb7f5c2ac26195b208e81c09e0ace7f403 Mon Sep 17 00:00:00 2001 From: renato2099 Date: Sun, 1 Jun 2025 22:06:38 +0200 Subject: [PATCH 01/18] Exposing FFI to python --- examples/datafusion-ffi-example/Cargo.lock | 1 + examples/datafusion-ffi-example/Cargo.toml | 1 + .../python/tests/_test_catalog_provider.py | 63 ++++++ .../src/catalog_provider.rs | 181 ++++++++++++++++++ examples/datafusion-ffi-example/src/lib.rs | 3 + python/datafusion/context.py | 14 ++ src/context.rs | 34 ++++ 7 files changed, 297 insertions(+) create mode 100644 examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py create mode 100644 examples/datafusion-ffi-example/src/catalog_provider.rs diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index 075ebd5a..e5a1ca8d 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -1448,6 +1448,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-schema", + "async-trait", "datafusion", "datafusion-ffi", "pyo3", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 0e17567b..31916355 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } +async-trait = "0.1.88" [build-dependencies] pyo3-build-config = "0.23" diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py new file mode 100644 index 00000000..b2c25bef --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa + +from datafusion import SessionContext +from datafusion_ffi_example import MyCatalogProvider + +def test_catalog_provider(): + ctx = SessionContext() + + my_catalog_name = "my_catalog" + expected_schema_name = "my_schema" + expected_table_name = "my_table" + expected_table_columns = ['units', 'price'] + + catalog_provider = MyCatalogProvider() + ctx.register_catalog_provider(my_catalog_name, catalog_provider) + my_catalog = ctx.catalog(my_catalog_name) + + my_catalog_schemas = my_catalog.names() + assert expected_schema_name in my_catalog_schemas + my_database = my_catalog.database(expected_schema_name) + assert expected_table_name in my_database.names() + my_table = my_database.table(expected_table_name) + assert expected_table_columns == my_table.schema.names + + ctx.register_table(expected_table_name, my_table) + expected_df = ctx.sql(f"SELECT * FROM {expected_table_name}").to_pandas() + assert len(expected_df) == 5 + assert expected_table_columns == expected_df.columns.tolist() + + result = ctx.table(f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}").collect() + assert len(result) == 2 + + col0_result = [r.column(0) for r in result] + col1_result = [r.column(1) for r in result] + expected_col0 = [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array([5, 7], type=pa.int32()), + ] + expected_col1 = [ + pa.array([1, 2, 5], type=pa.float64()), + pa.array([1.5, 2.5], type=pa.float64()), + ] + assert col0_result == expected_col0 + assert col1_result == expected_col1 \ No newline at end of file diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs new file mode 100644 index 00000000..32894ccd --- /dev/null +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, fmt::Debug, sync::Arc}; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; + +use arrow::datatypes::Schema; +use async_trait::async_trait; +use datafusion::{ + catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, + TableProvider, + }, + common::exec_err, + datasource::MemTable, + error::{DataFusionError, Result}, +}; +use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use pyo3::types::PyCapsule; + +pub fn my_table() -> Arc { + use arrow::datatypes::{DataType, Field}; + use datafusion::common::record_batch; + + let schema = Arc::new(Schema::new(vec![ + Field::new("units", DataType::Int32, true), + Field::new("price", DataType::Float64, true), + ])); + + let partitions = vec![ + record_batch!( + ("units", Int32, vec![10, 20, 30]), + ("price", Float64, vec![1.0, 2.0, 5.0]) + ) + .unwrap(), + record_batch!( + ("units", Int32, vec![5, 7]), + ("price", Float64, vec![1.5, 2.5]) + ) + .unwrap(), + ]; + + Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap()) +} + +#[derive(Debug)] +pub struct FixedSchemaProvider { + inner: MemorySchemaProvider, +} + +impl Default for FixedSchemaProvider { + fn default() -> Self { + let inner = MemorySchemaProvider::new(); + + let table = my_table(); + + let _ = inner + .register_table("my_table".to_string(), table) + .unwrap(); + + Self { inner } + } +} + +#[async_trait] +impl SchemaProvider for FixedSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.inner.table_names() + } + + async fn table( + &self, + name: &str, + ) -> Result>, DataFusionError> { + self.inner.table(name).await + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + self.inner.register_table(name, table) + } + + fn deregister_table(&self, name: &str) -> Result>> { + self.inner.deregister_table(name) + } + + fn table_exist(&self, name: &str) -> bool { + self.inner.table_exist(name) + } +} + + +/// This catalog provider is intended only for unit tests. It prepopulates with one +/// schema and only allows for schemas named after four types of fruit. +#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)] +#[derive(Debug)] +pub(crate) struct MyCatalogProvider { + inner: MemoryCatalogProvider, +} + +impl Default for MyCatalogProvider { + fn default() -> Self { + let inner = MemoryCatalogProvider::new(); + + let schema_name: &str = "my_schema"; + let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default())); + + Self { inner } + } +} + +impl CatalogProvider for MyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option> { + self.inner.schema(name) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + self.inner.register_schema(name, schema) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + self.inner.deregister_schema(name, cascade) + } +} + +#[pymethods] +impl MyCatalogProvider { + #[new] + pub fn new() -> Self { + Self { + inner: Default::default(), + } + } + + pub fn __datafusion_catalog_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_catalog_provider".into(); + let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + + PyCapsule::new(py, catalog_provider, Some(name)) + } +} \ No newline at end of file diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index ae08c3b6..76c1559e 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -17,14 +17,17 @@ use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; +use crate::catalog_provider::MyCatalogProvider; use pyo3::prelude::*; pub(crate) mod table_function; pub(crate) mod table_provider; +pub(crate) mod catalog_provider; #[pymodule] fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5b99b0d2..06593508 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -80,6 +80,14 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + def __datafusion_catalog_provider__(self) -> object: ... + + class SessionConfig: """Session configuration options.""" @@ -749,6 +757,12 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def register_catalog_provider( + self, name: str, provider: CatalogProviderExportable + ) -> None: + """Register a catalog provider.""" + self.ctx.register_catalog_provider(name, provider) + def register_table_provider( self, name: str, provider: TableProviderExportable ) -> None: diff --git a/src/context.rs b/src/context.rs index 6ce1f12b..350da498 100644 --- a/src/context.rs +++ b/src/context.rs @@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::CatalogProvider; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -70,6 +71,7 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use tokio::task::JoinHandle; @@ -614,6 +616,38 @@ impl PySessionContext { Ok(()) } + pub fn register_catalog_provider( + &mut self, + name: &str, + provider: Bound<'_, PyAny>, + ) -> PyDataFusionResult<()> { + if provider.hasattr("__datafusion_catalog_provider__")? { + let capsule = provider.getattr("__datafusion_catalog_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_catalog_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignCatalogProvider = provider.into(); + + let option: Option> = self.ctx.register_catalog(name, Arc::new(provider)); + match option { + Some(existing) => { + println!("Catalog '{}' already existed, schema names: {:?}", name, existing.schema_names()); + } + None => { + println!("Catalog '{}' registered successfully", name); + } + } + + Ok(()) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_catalog_provider__ does not exist on Catalog Provider object." + .to_string(), + )) + } + } + /// Construct datafusion dataframe from Arrow Table pub fn register_table_provider( &mut self, From d2667226afe36051d5ad80ed89c03b4631cd8e31 Mon Sep 17 00:00:00 2001 From: renato2099 Date: Sun, 1 Jun 2025 23:18:39 +0200 Subject: [PATCH 02/18] Exposing FFI to python --- .../python/tests/_test_catalog_provider.py | 14 ++++++-------- python/datafusion/context.py | 5 +++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py index b2c25bef..1a57f6b4 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -22,13 +22,14 @@ from datafusion import SessionContext from datafusion_ffi_example import MyCatalogProvider + def test_catalog_provider(): ctx = SessionContext() my_catalog_name = "my_catalog" expected_schema_name = "my_schema" expected_table_name = "my_table" - expected_table_columns = ['units', 'price'] + expected_table_columns = ["units", "price"] catalog_provider = MyCatalogProvider() ctx.register_catalog_provider(my_catalog_name, catalog_provider) @@ -41,12 +42,9 @@ def test_catalog_provider(): my_table = my_database.table(expected_table_name) assert expected_table_columns == my_table.schema.names - ctx.register_table(expected_table_name, my_table) - expected_df = ctx.sql(f"SELECT * FROM {expected_table_name}").to_pandas() - assert len(expected_df) == 5 - assert expected_table_columns == expected_df.columns.tolist() - - result = ctx.table(f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}").collect() + result = ctx.table( + f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}" + ).collect() assert len(result) == 2 col0_result = [r.column(0) for r in result] @@ -60,4 +58,4 @@ def test_catalog_provider(): pa.array([1.5, 2.5], type=pa.float64()), ] assert col0_result == expected_col0 - assert col1_result == expected_col1 \ No newline at end of file + assert col1_result == expected_col1 diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 06593508..c080931e 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -85,7 +85,8 @@ class CatalogProviderExportable(Protocol): https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html """ - def __datafusion_catalog_provider__(self) -> object: ... + + def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105 class SessionConfig: @@ -758,7 +759,7 @@ def deregister_table(self, name: str) -> None: self.ctx.deregister_table(name) def register_catalog_provider( - self, name: str, provider: CatalogProviderExportable + self, name: str, provider: CatalogProviderExportable ) -> None: """Register a catalog provider.""" self.ctx.register_catalog_provider(name, provider) From f1eb6bb244a45cb2d78bc8d19628cd1738d88569 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 16 Jun 2025 14:04:11 -0400 Subject: [PATCH 03/18] Workin progress on python catalog --- src/catalog.rs | 163 +++++++++++++++++++++++++++++++++++++++++++++---- src/context.rs | 15 +++-- 2 files changed, 163 insertions(+), 15 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 83f8d08c..0c8c2bf8 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -15,19 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::sync::Arc; - -use pyo3::exceptions::PyKeyError; -use pyo3::prelude::*; - -use crate::errors::{PyDataFusionError, PyDataFusionResult}; -use crate::utils::wait_for_future; +use crate::dataset::Dataset; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::utils::{validate_pycapsule, wait_for_future}; +use async_trait::async_trait; +use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, catalog::{CatalogProvider, SchemaProvider}, datasource::{TableProvider, TableType}, }; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; #[pyclass(name = "Catalog", module = "datafusion", subclass)] pub struct PyCatalog { @@ -50,8 +54,8 @@ impl PyCatalog { } } -impl PyDatabase { - pub fn new(database: Arc) -> Self { +impl From> for PyDatabase { + fn from(database: Arc) -> Self { Self { database } } } @@ -75,7 +79,7 @@ impl PyCatalog { #[pyo3(signature = (name="public"))] fn database(&self, name: &str) -> PyResult { match self.catalog.schema(name) { - Some(database) => Ok(PyDatabase::new(database)), + Some(database) => Ok(database.into()), None => Err(PyKeyError::new_err(format!( "Database with name {name} doesn't exist." ))), @@ -92,6 +96,13 @@ impl PyCatalog { #[pymethods] impl PyDatabase { + #[new] + fn new(schema_provider: PyObject) -> Self { + let schema_provider = + Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc; + schema_provider.into() + } + fn names(&self) -> HashSet { self.database.table_names().into_iter().collect() } @@ -145,3 +156,133 @@ impl PyTable { // fn has_exact_statistics // fn supports_filter_pushdown } + +#[derive(Debug)] +struct RustWrappedPySchemaProvider { + schema_provider: PyObject, + owner_name: Option, +} + +impl RustWrappedPySchemaProvider { + fn new(schema_provider: PyObject) -> Self { + let owner_name = Python::with_gil(|py| { + schema_provider + .bind(py) + .getattr("owner_name") + .ok() + .map(|name| name.to_string()) + }); + + Self { + schema_provider, + owner_name, + } + } + + fn table_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let py_table_method = provider.getattr("table")?; + + let py_table = py_table_method.call((name,), None)?; + if py_table.is_none() { + return Ok(None); + } + + if py_table.hasattr("__datafusion_table_provider__")? { + let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + + Ok(Some(Arc::new(ds) as Arc)) + } + }) + } +} + +#[async_trait] +impl SchemaProvider for RustWrappedPySchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + provider + .getattr("table_names") + .and_then(|names| names.extract::>()) + .unwrap_or_default() + }) + } + + async fn table( + &self, + name: &str, + ) -> datafusion::common::Result>, DataFusionError> { + self.table_inner(name).map_err(to_datafusion_err) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion::common::Result>> { + let py_table = PyTable::new(table); + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let _ = provider + .call_method1("register_table", (name, py_table)) + .map_err(to_datafusion_err)?; + // Since the definition of `register_table` says that an error + // will be returned if the table already exists, there is no + // case where we want to return a table provider as output. + Ok(None) + }) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let table = provider + .call_method1("deregister_table", (name,)) + .map_err(to_datafusion_err)?; + if table.is_none() { + return Ok(None); + } + + // If we can turn this table provider into a `Dataset`, return it. + // Otherwise, return None. + let dataset = match Dataset::new(&table, py) { + Ok(dataset) => Some(Arc::new(dataset) as Arc), + Err(_) => None, + }; + + Ok(dataset) + }) + } + + fn table_exist(&self, name: &str) -> bool { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + provider + .call_method1("table_exist", (name,)) + .and_then(|pyobj| pyobj.extract()) + .unwrap_or(false) + }) + } +} diff --git a/src/context.rs b/src/context.rs index 350da498..ed635c90 100644 --- a/src/context.rs +++ b/src/context.rs @@ -70,8 +70,8 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; -use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use tokio::task::JoinHandle; @@ -622,17 +622,24 @@ impl PySessionContext { provider: Bound<'_, PyAny>, ) -> PyDataFusionResult<()> { if provider.hasattr("__datafusion_catalog_provider__")? { - let capsule = provider.getattr("__datafusion_catalog_provider__")?.call0()?; + let capsule = provider + .getattr("__datafusion_catalog_provider__")? + .call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; validate_pycapsule(capsule, "datafusion_catalog_provider")?; let provider = unsafe { capsule.reference::() }; let provider: ForeignCatalogProvider = provider.into(); - let option: Option> = self.ctx.register_catalog(name, Arc::new(provider)); + let option: Option> = + self.ctx.register_catalog(name, Arc::new(provider)); match option { Some(existing) => { - println!("Catalog '{}' already existed, schema names: {:?}", name, existing.schema_names()); + println!( + "Catalog '{}' already existed, schema names: {:?}", + name, + existing.schema_names() + ); } None => { println!("Catalog '{}' registered successfully", name); From e7aaf4769ade4d0103781e3f42089b034af6c87d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 17 Jun 2025 12:57:13 -0400 Subject: [PATCH 04/18] Flushing out schema and catalog providers --- Cargo.lock | 19 ++++++++ Cargo.toml | 2 + src/catalog.rs | 118 ++++++++++++++++++++++++++++++++++++++++++++++- src/context.rs | 2 +- src/functions.rs | 2 +- src/lib.rs | 3 ++ 6 files changed, 142 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 112167cb..a3e9336c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,12 @@ dependencies = [ "zstd", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.9" @@ -1503,6 +1509,7 @@ dependencies = [ "datafusion-proto", "datafusion-substrait", "futures", + "log", "mimalloc", "object_store", "prost", @@ -1510,6 +1517,7 @@ dependencies = [ "pyo3", "pyo3-async-runtimes", "pyo3-build-config", + "pyo3-log", "tokio", "url", "uuid", @@ -2953,6 +2961,17 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.24.2" diff --git a/Cargo.toml b/Cargo.toml index 4135e64e..1f7895a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} +pyo3-log = "0.12.4" arrow = { version = "55.1.0", features = ["pyarrow"] } datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] } datafusion-substrait = { version = "48.0.0", optional = true } @@ -49,6 +50,7 @@ async-trait = "0.1.88" futures = "0.3" object_store = { version = "0.12.1", features = ["aws", "gcp", "azure", "http"] } url = "2" +log = "0.4.27" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/src/catalog.rs b/src/catalog.rs index 0c8c2bf8..aace7e40 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -25,6 +25,7 @@ use datafusion::{ catalog::{CatalogProvider, SchemaProvider}, datasource::{TableProvider, TableType}, }; +use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; @@ -48,8 +49,8 @@ pub struct PyTable { pub table: Arc, } -impl PyCatalog { - pub fn new(catalog: Arc) -> Self { +impl From> for PyCatalog { + fn from(catalog: Arc) -> Self { Self { catalog } } } @@ -72,6 +73,13 @@ impl PyTable { #[pymethods] impl PyCatalog { + #[new] + fn new(catalog: PyObject) -> Self { + let catalog_provider = + Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc; + catalog_provider.into() + } + fn names(&self) -> Vec { self.catalog.schema_names() } @@ -286,3 +294,109 @@ impl SchemaProvider for RustWrappedPySchemaProvider { }) } } + +#[derive(Debug)] +struct RustWrappedPyCatalogProvider { + catalog_provider: PyObject, +} + +impl RustWrappedPyCatalogProvider { + fn new(catalog_provider: PyObject) -> Self { + Self { catalog_provider } + } + + fn schema_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + + let py_schema = provider.call_method1("schema", (name,))?; + if py_schema.is_none() { + return Ok(None); + } + + if py_schema.hasattr("__datafusion_schema_provider__")? { + let capsule = provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); + + Ok(Some(Arc::new(py_schema) as Arc)) + } + }) + } +} + +#[async_trait] +impl CatalogProvider for RustWrappedPyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + provider + .getattr("schema_names") + .and_then(|names| names.extract::>()) + .unwrap_or_default() + }) + } + + fn schema(&self, name: &str) -> Option> { + self.schema_inner(name).unwrap_or_else(|err| { + log::error!("CatalogProvider schema returned error: {err}"); + None + }) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion::common::Result>> { + let py_schema: PyDatabase = schema.into(); + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("register_schema", (name, py_schema)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("deregister_schema", (name, cascade)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } +} diff --git a/src/context.rs b/src/context.rs index ed635c90..c57b6681 100644 --- a/src/context.rs +++ b/src/context.rs @@ -888,7 +888,7 @@ impl PySessionContext { #[pyo3(signature = (name="datafusion"))] pub fn catalog(&self, name: &str) -> PyResult { match self.ctx.catalog(name) { - Some(catalog) => Ok(PyCatalog::new(catalog)), + Some(catalog) => Ok(PyCatalog::from(catalog)), None => Err(PyKeyError::new_err(format!( "Catalog with name {} doesn't exist.", &name, diff --git a/src/functions.rs b/src/functions.rs index b2bafcb6..b40500b8 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -937,7 +937,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; - m.add_wrapped(wrap_pyfunction!(log))?; + m.add_wrapped(wrap_pyfunction!(self::log))?; m.add_wrapped(wrap_pyfunction!(log10))?; m.add_wrapped(wrap_pyfunction!(log2))?; m.add_wrapped(wrap_pyfunction!(lower))?; diff --git a/src/lib.rs b/src/lib.rs index 1293eee3..414d1c36 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,6 +77,9 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime); /// datafusion directory. #[pymodule] fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { + // Initialize logging + pyo3_log::init(); + // Register the python classes m.add_class::()?; m.add_class::()?; From 062e6c27a9e4f39fe66d475903aa7360320c1899 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Jun 2025 09:24:53 -0400 Subject: [PATCH 05/18] Adding implementation of python based catalog and schema providers --- .../python/tests/_test_catalog_provider.py | 1 - .../src/catalog_provider.rs | 30 ++- examples/datafusion-ffi-example/src/lib.rs | 4 +- python/datafusion/catalog.py | 80 +++++++- python/tests/test_catalog.py | 102 +++++++++- src/catalog.rs | 176 ++++++++++++++---- src/context.rs | 55 +++--- src/lib.rs | 7 +- 8 files changed, 357 insertions(+), 98 deletions(-) diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py index 1a57f6b4..72aadf64 100644 --- a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -18,7 +18,6 @@ from __future__ import annotations import pyarrow as pa - from datafusion import SessionContext from datafusion_ffi_example import MyCatalogProvider diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs index 32894ccd..54e61cf3 100644 --- a/examples/datafusion-ffi-example/src/catalog_provider.rs +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, fmt::Debug, sync::Arc}; use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::{any::Any, fmt::Debug, sync::Arc}; use arrow::datatypes::Schema; use async_trait::async_trait; use datafusion::{ catalog::{ - CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, - TableProvider, + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider, }, common::exec_err, datasource::MemTable, @@ -46,12 +45,12 @@ pub fn my_table() -> Arc { ("units", Int32, vec![10, 20, 30]), ("price", Float64, vec![1.0, 2.0, 5.0]) ) - .unwrap(), + .unwrap(), record_batch!( ("units", Int32, vec![5, 7]), ("price", Float64, vec![1.5, 2.5]) ) - .unwrap(), + .unwrap(), ]; Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap()) @@ -68,9 +67,7 @@ impl Default for FixedSchemaProvider { let table = my_table(); - let _ = inner - .register_table("my_table".to_string(), table) - .unwrap(); + let _ = inner.register_table("my_table".to_string(), table).unwrap(); Self { inner } } @@ -86,10 +83,7 @@ impl SchemaProvider for FixedSchemaProvider { self.inner.table_names() } - async fn table( - &self, - name: &str, - ) -> Result>, DataFusionError> { + async fn table(&self, name: &str) -> Result>, DataFusionError> { self.inner.table(name).await } @@ -110,10 +104,13 @@ impl SchemaProvider for FixedSchemaProvider { } } - /// This catalog provider is intended only for unit tests. It prepopulates with one /// schema and only allows for schemas named after four types of fruit. -#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)] +#[pyclass( + name = "MyCatalogProvider", + module = "datafusion_ffi_example", + subclass +)] #[derive(Debug)] pub(crate) struct MyCatalogProvider { inner: MemoryCatalogProvider, @@ -174,8 +171,9 @@ impl MyCatalogProvider { py: Python<'py>, ) -> PyResult> { let name = cr"datafusion_catalog_provider".into(); - let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + let catalog_provider = + FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); PyCapsule::new(py, catalog_provider, Some(name)) } -} \ No newline at end of file +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 76c1559e..3a4cf224 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::catalog_provider::MyCatalogProvider; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; -use crate::catalog_provider::MyCatalogProvider; use pyo3::prelude::*; +pub(crate) mod catalog_provider; pub(crate) mod table_function; pub(crate) mod table_provider; -pub(crate) mod catalog_provider; #[pymodule] fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 67ab3ead..cbea6bd6 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -26,11 +26,23 @@ if TYPE_CHECKING: import pyarrow as pa +try: + from warnings import deprecated # Python 3.13+ +except ImportError: + from typing_extensions import deprecated # Python 3.12 + + +__all__ = [ + "Catalog", + "Schema", + "Table", +] + class Catalog: """DataFusion data catalog.""" - def __init__(self, catalog: df_internal.Catalog) -> None: + def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None: """This constructor is not typically called by the end user.""" self.catalog = catalog @@ -59,18 +71,74 @@ def __repr__(self) -> str: return self.db.__repr__() def names(self) -> set[str]: - """Returns the list of all tables in this database.""" - return self.db.names() + """This is an alias for `schema_names`.""" + return self.schema_names() + + def schema_names(self) -> set[str]: + """Returns the list of schemas in this catalog.""" + return self.catalog.schema_names() + + def schema(self, name: str = "public") -> Schema: + """Returns the database with the given ``name`` from this catalog.""" + schema = self.catalog.schema(name) + + return ( + Schema(schema) + if isinstance(schema, df_internal.catalog.RawSchema) + else schema + ) + + @deprecated("Use `schema` instead.") + def database(self, name: str = "public") -> Schema: + """Returns the database with the given ``name`` from this catalog.""" + return self.schema(name) + + def register_schema(self, name, schema) -> Schema | None: + """Register a schema with this catalog.""" + return self.catalog.register_schema(name, schema) + + def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None: + """Deregister a schema from this catalog.""" + return self.catalog.deregister_schema(name, cascade) + + +class Schema: + """DataFusion Schema.""" + + def __init__(self, schema: df_internal.catalog.RawSchema) -> None: + """This constructor is not typically called by the end user.""" + self._raw_schema = schema + + def names(self) -> set[str]: + """This is an alias for `table_names`.""" + return self.table_names() + + def table_names(self) -> set[str]: + """Returns the list of all tables in this schema.""" + return self._raw_schema.table_names def table(self, name: str) -> Table: - """Return the table with the given ``name`` from this database.""" - return Table(self.db.table(name)) + """Return the table with the given ``name`` from this schema.""" + return Table(self._raw_schema.table(name)) + + def register_table(self, name, table) -> None: + """Register a table provider in this schema.""" + return self._raw_schema.register_table(name, table) + + def deregister_table(self, name: str) -> None: + """Deregister a table provider from this schema.""" + return self._raw_schema.deregister_table(name) + + +@deprecated("Use `Schema` instead.") +class Database(Schema): + """See `Schema`.""" class Table: """DataFusion table.""" - def __init__(self, table: df_internal.Table) -> None: + def __init__(self, table: df_internal.catalog.RawTable) -> None: """This constructor is not typically called by the end user.""" self.table = table diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 23b32845..21b0a3e0 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -15,8 +15,11 @@ # specific language governing permissions and limitations # under the License. +import datafusion as dfn import pyarrow as pa +import pyarrow.dataset as ds import pytest +from datafusion import SessionContext, Table # Note we take in `database` as a variable even though we don't use @@ -27,7 +30,7 @@ def test_basic(ctx, database): ctx.catalog("non-existent") default = ctx.catalog() - assert default.names() == ["public"] + assert default.names() == {"public"} for db in [default.database("public"), default.database()]: assert db.names() == {"csv1", "csv", "csv2"} @@ -41,3 +44,100 @@ def test_basic(ctx, database): pa.field("float", pa.float64(), nullable=True), ] ) + + +class CustomTableProvider: + def __init__(self): + pass + + +def create_dataset() -> pa.dataset.Dataset: + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + return ds.dataset([batch]) + + +class CustomSchemaProvider: + def __init__(self): + self.tables = {"table1": create_dataset()} + + def table_names(self) -> set[str]: + return set(self.tables.keys()) + + def register_table(self, name: str, table: Table): + self.tables[name] = table + + def deregister_table(self, name, cascade: bool = True): + del self.tables[name] + + +class CustomCatalogProvider: + def __init__(self): + self.schemas = {"my_schema": CustomSchemaProvider()} + + def schema_names(self) -> set[str]: + return set(self.schemas.keys()) + + def schema(self, name: str): + return self.schemas[name] + + def register_schema(self, name: str, schema: dfn.catalog.Schema): + self.schemas[name] = schema + + def deregister_schema(self, name, cascade: bool): + del self.schemas[name] + + +def test_python_catalog_provider(ctx: SessionContext): + ctx.register_catalog_provider("my_catalog", CustomCatalogProvider()) + + # Check the default catalog provider + assert ctx.catalog("datafusion").names() == {"public"} + + my_catalog = ctx.catalog("my_catalog") + assert my_catalog.names() == {"my_schema"} + + my_catalog.register_schema("second_schema", CustomSchemaProvider()) + assert my_catalog.schema_names() == {"my_schema", "second_schema"} + + my_catalog.deregister_schema("my_schema") + assert my_catalog.schema_names() == {"second_schema"} + + +def test_python_schema_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.deregister_schema("public") + + catalog.register_schema("test_schema1", CustomSchemaProvider()) + assert catalog.names() == {"test_schema1"} + + catalog.register_schema("test_schema2", CustomSchemaProvider()) + catalog.deregister_schema("test_schema1") + assert catalog.names() == {"test_schema2"} + + +def test_python_table_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.register_schema("custom_schema", CustomSchemaProvider()) + schema = catalog.schema("custom_schema") + + assert schema.table_names() == {"table1"} + + schema.deregister_table("table1") + schema.register_table("table2", create_dataset()) + assert schema.table_names() == {"table2"} + + # Use the default schema instead of our custom schema + + schema = catalog.schema() + + schema.register_table("table3", create_dataset()) + assert schema.table_names() == {"table3"} + + schema.deregister_table("table3") + schema.register_table("table4", create_dataset()) + assert schema.table_names() == {"table4"} diff --git a/src/catalog.rs b/src/catalog.rs index aace7e40..ab957fb2 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -30,21 +30,22 @@ use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; use pyo3::types::PyCapsule; +use pyo3::IntoPyObjectExt; use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -#[pyclass(name = "Catalog", module = "datafusion", subclass)] +#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] pub struct PyCatalog { pub catalog: Arc, } -#[pyclass(name = "Database", module = "datafusion", subclass)] -pub struct PyDatabase { - pub database: Arc, +#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +pub struct PySchema { + pub schema: Arc, } -#[pyclass(name = "Table", module = "datafusion", subclass)] +#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] pub struct PyTable { pub table: Arc, } @@ -55,9 +56,9 @@ impl From> for PyCatalog { } } -impl From> for PyDatabase { - fn from(database: Arc) -> Self { - Self { database } +impl From> for PySchema { + fn from(schema: Arc) -> Self { + Self { schema } } } @@ -80,30 +81,72 @@ impl PyCatalog { catalog_provider.into() } - fn names(&self) -> Vec { - self.catalog.schema_names() + fn schema_names(&self) -> HashSet { + self.catalog.schema_names().into_iter().collect() } #[pyo3(signature = (name="public"))] - fn database(&self, name: &str) -> PyResult { - match self.catalog.schema(name) { - Some(database) => Ok(database.into()), - None => Err(PyKeyError::new_err(format!( - "Database with name {name} doesn't exist." - ))), - } + fn schema(&self, name: &str) -> PyResult { + let schema = self + .catalog + .schema(name) + .ok_or(PyKeyError::new_err(format!( + "Schema with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)), + None => PySchema::from(schema).into_py_any(py), + } + }) + } + + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { + let capsule = schema_provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPySchemaProvider::new(schema_provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self + .catalog + .register_schema(name, provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> { + let _ = self + .catalog + .deregister_schema(name, cascade) + .map_err(py_datafusion_err)?; + + Ok(()) } fn __repr__(&self) -> PyResult { - Ok(format!( - "Catalog(schema_names=[{}])", - self.names().join(";") - )) + let mut names: Vec = self.schema_names().into_iter().collect(); + names.sort(); + Ok(format!("Catalog(schema_names=[{}])", names.join(", "))) } } #[pymethods] -impl PyDatabase { +impl PySchema { #[new] fn new(schema_provider: PyObject) -> Self { let schema_provider = @@ -111,8 +154,9 @@ impl PyDatabase { schema_provider.into() } - fn names(&self) -> HashSet { - self.database.table_names().into_iter().collect() + #[getter] + fn table_names(&self) -> HashSet { + self.schema.table_names().into_iter().collect() } fn table(&self, name: &str, py: Python) -> PyDataFusionResult { @@ -126,14 +170,44 @@ impl PyDatabase { } fn __repr__(&self) -> PyResult { - Ok(format!( - "Database(table_names=[{}])", - Vec::from_iter(self.names()).join(";") - )) + let mut names: Vec = self.table_names().into_iter().collect(); + names.sort(); + Ok(format!("Schema(table_names=[{}])", names.join(";"))) } - // register_table - // deregister_table + fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if table_provider.hasattr("__datafusion_table_provider__")? { + let capsule = table_provider + .getattr("__datafusion_table_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc + }; + + let _ = self + .schema + .register_table(name.to_string(), provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_table(&self, name: &str) -> PyResult<()> { + let _ = self + .schema + .deregister_table(name) + .map_err(py_datafusion_err)?; + + Ok(()) + } } #[pymethods] @@ -166,13 +240,13 @@ impl PyTable { } #[derive(Debug)] -struct RustWrappedPySchemaProvider { +pub(crate) struct RustWrappedPySchemaProvider { schema_provider: PyObject, owner_name: Option, } impl RustWrappedPySchemaProvider { - fn new(schema_provider: PyObject) -> Self { + pub fn new(schema_provider: PyObject) -> Self { let owner_name = Python::with_gil(|py| { schema_provider .bind(py) @@ -228,10 +302,14 @@ impl SchemaProvider for RustWrappedPySchemaProvider { fn table_names(&self) -> Vec { Python::with_gil(|py| { let provider = self.schema_provider.bind(py); + provider .getattr("table_names") .and_then(|names| names.extract::>()) - .unwrap_or_default() + .unwrap_or_else(|err| { + log::error!("Unable to get table_names: {err}"); + Vec::default() + }) }) } @@ -296,12 +374,12 @@ impl SchemaProvider for RustWrappedPySchemaProvider { } #[derive(Debug)] -struct RustWrappedPyCatalogProvider { - catalog_provider: PyObject, +pub(crate) struct RustWrappedPyCatalogProvider { + pub(crate) catalog_provider: PyObject, } impl RustWrappedPyCatalogProvider { - fn new(catalog_provider: PyObject) -> Self { + pub fn new(catalog_provider: PyObject) -> Self { Self { catalog_provider } } @@ -346,7 +424,10 @@ impl CatalogProvider for RustWrappedPyCatalogProvider { provider .getattr("schema_names") .and_then(|names| names.extract::>()) - .unwrap_or_default() + .unwrap_or_else(|err| { + log::error!("Unable to get schema_names: {err}"); + Vec::default() + }) }) } @@ -362,8 +443,19 @@ impl CatalogProvider for RustWrappedPyCatalogProvider { name: &str, schema: Arc, ) -> datafusion::common::Result>> { - let py_schema: PyDatabase = schema.into(); + // JRIGHT HERE + // let py_schema: PySchema = schema.into(); Python::with_gil(|py| { + let py_schema = match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(), + None => &PySchema::from(schema) + .into_py_any(py) + .map_err(to_datafusion_err)?, + }; + let provider = self.catalog_provider.bind(py); let schema = provider .call_method1("register_schema", (name, py_schema)) @@ -400,3 +492,11 @@ impl CatalogProvider for RustWrappedPyCatalogProvider { }) } } + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/src/context.rs b/src/context.rs index c57b6681..cb15c5f0 100644 --- a/src/context.rs +++ b/src/context.rs @@ -31,7 +31,7 @@ use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use crate::catalog::{PyCatalog, PyTable}; +use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; @@ -73,6 +73,7 @@ use datafusion::prelude::{ use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; +use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; /// Configuration options for a SessionContext @@ -621,7 +622,7 @@ impl PySessionContext { name: &str, provider: Bound<'_, PyAny>, ) -> PyDataFusionResult<()> { - if provider.hasattr("__datafusion_catalog_provider__")? { + let provider = if provider.hasattr("__datafusion_catalog_provider__")? { let capsule = provider .getattr("__datafusion_catalog_provider__")? .call0()?; @@ -630,29 +631,15 @@ impl PySessionContext { let provider = unsafe { capsule.reference::() }; let provider: ForeignCatalogProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPyCatalogProvider::new(provider.into()); + Arc::new(provider) as Arc + }; - let option: Option> = - self.ctx.register_catalog(name, Arc::new(provider)); - match option { - Some(existing) => { - println!( - "Catalog '{}' already existed, schema names: {:?}", - name, - existing.schema_names() - ); - } - None => { - println!("Catalog '{}' registered successfully", name); - } - } + let _ = self.ctx.register_catalog(name, provider); - Ok(()) - } else { - Err(crate::errors::PyDataFusionError::Common( - "__datafusion_catalog_provider__ does not exist on Catalog Provider object." - .to_string(), - )) - } + Ok(()) } /// Construct datafusion dataframe from Arrow Table @@ -886,14 +873,20 @@ impl PySessionContext { } #[pyo3(signature = (name="datafusion"))] - pub fn catalog(&self, name: &str) -> PyResult { - match self.ctx.catalog(name) { - Some(catalog) => Ok(PyCatalog::from(catalog)), - None => Err(PyKeyError::new_err(format!( - "Catalog with name {} doesn't exist.", - &name, - ))), - } + pub fn catalog(&self, name: &str) -> PyResult { + let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( + "Catalog with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match catalog + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)), + None => PyCatalog::from(catalog).into_py_any(py), + } + }) } pub fn tables(&self) -> HashSet { diff --git a/src/lib.rs b/src/lib.rs index 414d1c36..29d3f41d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,9 +81,6 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); // Register the python classes - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -101,6 +98,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + let catalog = PyModule::new(py, "catalog")?; + catalog::init_module(&catalog)?; + m.add_submodule(&catalog)?; + // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/ let common = PyModule::new(py, "common")?; common::init_module(&common)?; From 0b3cc24479b808b477fba018b31973b78d6dbb30 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Jun 2025 12:39:00 -0400 Subject: [PATCH 06/18] Small updates after rebase --- python/datafusion/catalog.py | 24 ++++-------------------- src/catalog.rs | 2 +- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index cbea6bd6..bebd3816 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -50,26 +50,6 @@ def __repr__(self) -> str: """Print a string representation of the catalog.""" return self.catalog.__repr__() - def names(self) -> list[str]: - """Returns the list of databases in this catalog.""" - return self.catalog.names() - - def database(self, name: str = "public") -> Database: - """Returns the database with the given ``name`` from this catalog.""" - return Database(self.catalog.database(name)) - - -class Database: - """DataFusion Database.""" - - def __init__(self, db: df_internal.Database) -> None: - """This constructor is not typically called by the end user.""" - self.db = db - - def __repr__(self) -> str: - """Print a string representation of the database.""" - return self.db.__repr__() - def names(self) -> set[str]: """This is an alias for `schema_names`.""" return self.schema_names() @@ -109,6 +89,10 @@ def __init__(self, schema: df_internal.catalog.RawSchema) -> None: """This constructor is not typically called by the end user.""" self._raw_schema = schema + def __repr__(self) -> str: + """Print a string representation of the schema.""" + return self._raw_schema.__repr__() + def names(self) -> set[str]: """This is an alias for `table_names`.""" return self.table_names() diff --git a/src/catalog.rs b/src/catalog.rs index ab957fb2..ba96ce47 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -160,7 +160,7 @@ impl PySchema { } fn table(&self, name: &str, py: Python) -> PyDataFusionResult { - if let Some(table) = wait_for_future(py, self.database.table(name))?? { + if let Some(table) = wait_for_future(py, self.schema.table(name))?? { Ok(PyTable::new(table)) } else { Err(PyDataFusionError::Common(format!( From 2b300b519988caa1fa59ec418642ca552449989f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Jun 2025 15:27:57 -0400 Subject: [PATCH 07/18] Add default in memory options for adding schema and catalogs --- python/datafusion/catalog.py | 5 +++++ python/datafusion/context.py | 9 +++++++++ src/catalog.rs | 11 +++++++++++ src/context.rs | 13 ++++++++++++- 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index bebd3816..5f1a317f 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -73,6 +73,11 @@ def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" return self.schema(name) + def new_in_memory_schema(self, name: str) -> Schema: + """Create a new schema in this catalog using an in-memory provider.""" + self.catalog.new_in_memory_schema(name) + return self.schema(name) + def register_schema(self, name, schema) -> Schema | None: """Register a schema with this catalog.""" return self.catalog.register_schema(name, schema) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c080931e..f752272b 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -758,6 +758,15 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def catalog_names(self) -> set[str]: + """Returns the list of catalogs in this context.""" + return self.ctx.catalog_names() + + def new_in_memory_catalog(self, name: str) -> Catalog: + """Create a new catalog in this context using an in-memory provider.""" + self.ctx.new_in_memory_catalog(name) + return self.catalog(name) + def register_catalog_provider( self, name: str, provider: CatalogProviderExportable ) -> None: diff --git a/src/catalog.rs b/src/catalog.rs index ba96ce47..9a24f2d4 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -19,6 +19,7 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::utils::{validate_pycapsule, wait_for_future}; use async_trait::async_trait; +use datafusion::catalog::MemorySchemaProvider; use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, @@ -105,6 +106,16 @@ impl PyCatalog { }) } + fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> { + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let _ = self + .catalog + .register_schema(name, schema) + .map_err(py_datafusion_err)?; + + Ok(()) + } + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { let capsule = schema_provider diff --git a/src/context.rs b/src/context.rs index cb15c5f0..c97f2f61 100644 --- a/src/context.rs +++ b/src/context.rs @@ -49,7 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::CatalogProvider; +use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -617,6 +617,13 @@ impl PySessionContext { Ok(()) } + pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> { + let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc; + let _ = self.ctx.register_catalog(name, catalog); + + Ok(()) + } + pub fn register_catalog_provider( &mut self, name: &str, @@ -889,6 +896,10 @@ impl PySessionContext { }) } + pub fn catalog_names(&self) -> HashSet { + self.ctx.catalog_names().into_iter().collect() + } + pub fn tables(&self) -> HashSet { self.ctx .catalog_names() From 43d87a6aa50a045fe0039a42146dfb4746316e34 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 09:23:03 -0400 Subject: [PATCH 08/18] Add support for creating in memory catalog and schema --- python/datafusion/catalog.py | 19 ++++++++++++++----- python/datafusion/context.py | 12 +++++------- python/tests/test_catalog.py | 18 ++++++++++++++++++ src/catalog.rs | 35 ++++++++++++++++++++++------------- src/context.rs | 23 +++++++++++++---------- 5 files changed, 72 insertions(+), 35 deletions(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 5f1a317f..58abe830 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -58,6 +58,12 @@ def schema_names(self) -> set[str]: """Returns the list of schemas in this catalog.""" return self.catalog.schema_names() + @staticmethod + def memory_catalog() -> Catalog: + """Create an in-memory catalog provider.""" + catalog = df_internal.catalog.RawCatalog.memory_catalog() + return Catalog(catalog) + def schema(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" schema = self.catalog.schema(name) @@ -73,13 +79,10 @@ def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" return self.schema(name) - def new_in_memory_schema(self, name: str) -> Schema: - """Create a new schema in this catalog using an in-memory provider.""" - self.catalog.new_in_memory_schema(name) - return self.schema(name) - def register_schema(self, name, schema) -> Schema | None: """Register a schema with this catalog.""" + if isinstance(schema, Schema): + return self.catalog.register_schema(name, schema._raw_schema) return self.catalog.register_schema(name, schema) def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None: @@ -98,6 +101,12 @@ def __repr__(self) -> str: """Print a string representation of the schema.""" return self._raw_schema.__repr__() + @staticmethod + def memory_schema() -> Schema: + """Create an in-memory schema provider.""" + schema = df_internal.catalog.RawSchema.memory_schema() + return Schema(schema) + def names(self) -> set[str]: """This is an alias for `table_names`.""" return self.table_names() diff --git a/python/datafusion/context.py b/python/datafusion/context.py index f752272b..c652d4e8 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -762,16 +762,14 @@ def catalog_names(self) -> set[str]: """Returns the list of catalogs in this context.""" return self.ctx.catalog_names() - def new_in_memory_catalog(self, name: str) -> Catalog: - """Create a new catalog in this context using an in-memory provider.""" - self.ctx.new_in_memory_catalog(name) - return self.catalog(name) - def register_catalog_provider( - self, name: str, provider: CatalogProviderExportable + self, name: str, provider: CatalogProviderExportable | Catalog ) -> None: """Register a catalog provider.""" - self.ctx.register_catalog_provider(name, provider) + if isinstance(provider, Catalog): + self.ctx.register_catalog_provider(name, provider.catalog) + else: + self.ctx.register_catalog_provider(name, provider) def register_table_provider( self, name: str, provider: TableProviderExportable diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 21b0a3e0..264bdaa9 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -106,6 +106,24 @@ def test_python_catalog_provider(ctx: SessionContext): assert my_catalog.schema_names() == {"second_schema"} +def test_in_memory_providers(ctx: SessionContext): + catalog = dfn.catalog.Catalog.memory_catalog() + ctx.register_catalog_provider("in_mem_catalog", catalog) + + assert ctx.catalog_names() == {"datafusion", "in_mem_catalog"} + + schema = dfn.catalog.Schema.memory_schema() + catalog.register_schema("in_mem_schema", schema) + + schema.register_table("my_table", create_dataset()) + + batches = ctx.sql("select * from in_mem_catalog.in_mem_schema.my_table").collect() + + assert len(batches) == 1 + assert batches[0].column(0) == pa.array([1, 2, 3]) + assert batches[0].column(1) == pa.array([4, 5, 6]) + + def test_python_schema_provider(ctx: SessionContext): catalog = ctx.catalog() diff --git a/src/catalog.rs b/src/catalog.rs index 9a24f2d4..d85e4069 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -19,7 +19,7 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::utils::{validate_pycapsule, wait_for_future}; use async_trait::async_trait; -use datafusion::catalog::MemorySchemaProvider; +use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, @@ -37,16 +37,19 @@ use std::collections::HashSet; use std::sync::Arc; #[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] +#[derive(Clone)] pub struct PyCatalog { pub catalog: Arc, } #[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +#[derive(Clone)] pub struct PySchema { pub schema: Arc, } #[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] +#[derive(Clone)] pub struct PyTable { pub table: Arc, } @@ -82,6 +85,13 @@ impl PyCatalog { catalog_provider.into() } + #[staticmethod] + fn memory_catalog() -> Self { + let catalog_provider = + Arc::new(MemoryCatalogProvider::default()) as Arc; + catalog_provider.into() + } + fn schema_names(&self) -> HashSet { self.catalog.schema_names().into_iter().collect() } @@ -106,16 +116,6 @@ impl PyCatalog { }) } - fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> { - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - let _ = self - .catalog - .register_schema(name, schema) - .map_err(py_datafusion_err)?; - - Ok(()) - } - fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { let capsule = schema_provider @@ -128,8 +128,11 @@ impl PyCatalog { let provider: ForeignSchemaProvider = provider.into(); Arc::new(provider) as Arc } else { - let provider = RustWrappedPySchemaProvider::new(schema_provider.into()); - Arc::new(provider) as Arc + match schema_provider.extract::() { + Ok(py_schema) => py_schema.schema, + Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into())) + as Arc, + } }; let _ = self @@ -165,6 +168,12 @@ impl PySchema { schema_provider.into() } + #[staticmethod] + fn memory_schema() -> Self { + let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc; + schema_provider.into() + } + #[getter] fn table_names(&self) -> HashSet { self.schema.table_names().into_iter().collect() diff --git a/src/context.rs b/src/context.rs index c97f2f61..d4879e85 100644 --- a/src/context.rs +++ b/src/context.rs @@ -49,7 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; +use datafusion::catalog::CatalogProvider; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -617,13 +617,6 @@ impl PySessionContext { Ok(()) } - pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> { - let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc; - let _ = self.ctx.register_catalog(name, catalog); - - Ok(()) - } - pub fn register_catalog_provider( &mut self, name: &str, @@ -640,8 +633,18 @@ impl PySessionContext { let provider: ForeignCatalogProvider = provider.into(); Arc::new(provider) as Arc } else { - let provider = RustWrappedPyCatalogProvider::new(provider.into()); - Arc::new(provider) as Arc + println!("Provider has type {}", provider.get_type()); + match provider.extract::() { + Ok(py_catalog) => { + println!("registering an existing PyCatalog"); + py_catalog.catalog + } + Err(_) => { + println!("registering a rust wrapped catalog provider"); + Arc::new(RustWrappedPyCatalogProvider::new(provider.into())) + as Arc + } + } }; let _ = self.ctx.register_catalog(name, provider); From 364f2c720b366f8e0ed418b5ec755f992fc6e880 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 09:46:43 -0400 Subject: [PATCH 09/18] Update from database to schema in unit tests --- python/tests/test_catalog.py | 2 +- python/tests/test_context.py | 40 +++++++++++++++++----------------- python/tests/test_sql.py | 14 ++++++------ python/tests/test_substrait.py | 4 ++-- 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 264bdaa9..045ddbc0 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -32,7 +32,7 @@ def test_basic(ctx, database): default = ctx.catalog() assert default.names() == {"public"} - for db in [default.database("public"), default.database()]: + for db in [default.schema("public"), default.schema()]: assert db.names() == {"csv1", "csv", "csv2"} table = db.table("csv") diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 4a15ac9c..6dbcc0d5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -57,7 +57,7 @@ def test_runtime_configs(tmp_path, path_to_str): ctx = SessionContext(config, runtime) assert ctx is not None - db = ctx.catalog("foo").database("bar") + db = ctx.catalog("foo").schema("bar") assert db is not None @@ -70,7 +70,7 @@ def test_temporary_files(tmp_path, path_to_str): ctx = SessionContext(config, runtime) assert ctx is not None - db = ctx.catalog("foo").database("bar") + db = ctx.catalog("foo").schema("bar") assert db is not None @@ -91,7 +91,7 @@ def test_create_context_with_all_valid_args(): ctx = SessionContext(config, runtime) # verify that at least some of the arguments worked - ctx.catalog("foo").database("bar") + ctx.catalog("foo").schema("bar") with pytest.raises(KeyError): ctx.catalog("datafusion") @@ -105,7 +105,7 @@ def test_register_record_batches(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -121,7 +121,7 @@ def test_create_dataframe_registers_unique_table_name(ctx): ) df = ctx.create_dataframe([[batch]]) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -141,7 +141,7 @@ def test_create_dataframe_registers_with_defined_table_name(ctx): ) df = ctx.create_dataframe([[batch]], name="tbl") - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -155,7 +155,7 @@ def test_from_arrow_table(ctx): # convert to DataFrame df = ctx.from_arrow(table) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -200,7 +200,7 @@ def test_from_arrow_table_with_name(ctx): # convert to DataFrame with optional name df = ctx.from_arrow(table, name="tbl") - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert tables[0] == "tbl" @@ -213,7 +213,7 @@ def test_from_arrow_table_empty(ctx): # convert to DataFrame df = ctx.from_arrow(table) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -228,7 +228,7 @@ def test_from_arrow_table_empty_no_schema(ctx): # convert to DataFrame df = ctx.from_arrow(table) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -246,7 +246,7 @@ def test_from_pylist(ctx): ] df = ctx.from_pylist(data) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -260,7 +260,7 @@ def test_from_pydict(ctx): data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = ctx.from_pydict(data) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -276,7 +276,7 @@ def test_from_pandas(ctx): pandas_df = pd.DataFrame(data) df = ctx.from_pandas(pandas_df) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -292,7 +292,7 @@ def test_from_polars(ctx): polars_df = pd.DataFrame(data) df = ctx.from_polars(polars_df) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -303,7 +303,7 @@ def test_from_polars(ctx): def test_register_table(ctx, database): default = ctx.catalog() - public = default.database("public") + public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} table = public.table("csv") @@ -313,7 +313,7 @@ def test_register_table(ctx, database): def test_read_table(ctx, database): default = ctx.catalog() - public = default.database("public") + public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} table = public.table("csv") @@ -323,7 +323,7 @@ def test_read_table(ctx, database): def test_deregister_table(ctx, database): default = ctx.catalog() - public = default.database("public") + public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} ctx.deregister_table("csv") @@ -339,7 +339,7 @@ def test_register_dataset(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -356,7 +356,7 @@ def test_dataset_filter(ctx, capfd): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5") # Make sure the filter was pushed down in Physical Plan @@ -455,7 +455,7 @@ def test_dataset_filter_nested_data(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} df = ctx.table("t") diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 41cee4ef..b8a14a6e 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -75,7 +75,7 @@ def test_register_csv(ctx, tmp_path): ) ctx.register_csv("csv3", path, schema=alternative_schema) - assert ctx.catalog().database().names() == { + assert ctx.catalog().schema().names() == { "csv", "csv1", "csv2", @@ -150,7 +150,7 @@ def test_register_parquet(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) ctx.register_parquet("t", path) ctx.register_parquet("t1", str(path)) - assert ctx.catalog().database().names() == {"t", "t1"} + assert ctx.catalog().schema().names() == {"t", "t1"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) @@ -188,7 +188,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_ty parquet_pruning=True, file_extension=".parquet", ) - assert ctx.catalog().database().names() == {"datapp"} + assert ctx.catalog().schema().names() == {"datapp"} result = ctx.sql("SELECT grp, COUNT(*) AS cnt FROM datapp GROUP BY grp").collect() result = pa.Table.from_batches(result) @@ -204,7 +204,7 @@ def test_register_dataset(ctx, tmp_path, path_to_str): dataset = ds.dataset(path, format="parquet") ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) @@ -251,7 +251,7 @@ def test_register_json(ctx, tmp_path): ) ctx.register_json("json3", path, schema=alternative_schema) - assert ctx.catalog().database().names() == { + assert ctx.catalog().schema().names() == { "json", "json1", "json2", @@ -308,7 +308,7 @@ def test_execute(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data)) ctx.register_parquet("t", path) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} # count result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL").collect() @@ -524,7 +524,7 @@ def test_register_listing_table( schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, ) - assert ctx.catalog().database().names() == {"my_table"} + assert ctx.catalog().schema().names() == {"my_table"} result = ctx.sql( "SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp" diff --git a/python/tests/test_substrait.py b/python/tests/test_substrait.py index f367a447..43aa327d 100644 --- a/python/tests/test_substrait.py +++ b/python/tests/test_substrait.py @@ -34,7 +34,7 @@ def test_substrait_serialization(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} # For now just make sure the method calls blow up substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM t", ctx) @@ -59,7 +59,7 @@ def test_substrait_file_serialization(ctx, tmp_path, path_to_str): ctx.register_record_batches("t", [[batch]]) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} path = tmp_path / "substrait_plan" path = str(path) if path_to_str else path From e2b6aa2c48ae6d6bac15441bb69f2c6810a0a9d8 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 09:51:54 -0400 Subject: [PATCH 10/18] xfailed label no longer applies to these unit tests --- python/tests/test_sql.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index b8a14a6e..c383edc6 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -451,18 +451,10 @@ def test_udf( id="datetime_ns", ), # Not writtable to parquet - pytest.param( - helpers.data_timedelta("s"), id="timedelta_s", marks=pytest.mark.xfail - ), - pytest.param( - helpers.data_timedelta("ms"), id="timedelta_ms", marks=pytest.mark.xfail - ), - pytest.param( - helpers.data_timedelta("us"), id="timedelta_us", marks=pytest.mark.xfail - ), - pytest.param( - helpers.data_timedelta("ns"), id="timedelta_ns", marks=pytest.mark.xfail - ), + pytest.param(helpers.data_timedelta("s"), id="timedelta_s"), + pytest.param(helpers.data_timedelta("ms"), id="timedelta_ms"), + pytest.param(helpers.data_timedelta("us"), id="timedelta_us"), + pytest.param(helpers.data_timedelta("ns"), id="timedelta_ns"), ], ) def test_simple_select(ctx, tmp_path, arr): From 294a8a9b1e8ee62b8660143b0202902ac9017121 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 12:59:40 -0400 Subject: [PATCH 11/18] Defining abstract methods for catalog and schema providers --- python/datafusion/catalog.py | 76 ++++++++++++++++++++++++++++++++++++ python/tests/test_catalog.py | 21 +++++----- src/catalog.rs | 19 +++++++-- src/context.rs | 13 ++---- 4 files changed, 107 insertions(+), 22 deletions(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 58abe830..9e3af3ac 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -19,6 +19,7 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import TYPE_CHECKING import datafusion._internal as df_internal @@ -121,6 +122,8 @@ def table(self, name: str) -> Table: def register_table(self, name, table) -> None: """Register a table provider in this schema.""" + if isinstance(table, Table): + return self._raw_schema.register_table(name, table.table) return self._raw_schema.register_table(name, table) def deregister_table(self, name: str) -> None: @@ -144,6 +147,11 @@ def __repr__(self) -> str: """Print a string representation of the table.""" return self.table.__repr__() + @staticmethod + def from_dataset(dataset: pa.dataset.Dataset) -> Table: + """Turn a pyarrow Dataset into a Table.""" + return Table(df_internal.catalog.RawTable.from_dataset(dataset)) + @property def schema(self) -> pa.Schema: """Returns the schema associated with this table.""" @@ -153,3 +161,71 @@ def schema(self) -> pa.Schema: def kind(self) -> str: """Returns the kind of table.""" return self.table.kind + + +class CatalogProvider(ABC): + @abstractmethod + def schema_names(self) -> set[str]: + """Set of the names of all schemas in this catalog.""" + ... + + @abstractmethod + def schema(self, name: str) -> Schema | None: + """Retrieve a specific schema from this catalog.""" + ... + + def register_schema(self, name: str, schema: Schema) -> None: # noqa: B027 + """Add a schema to this catalog. + + This method is optional. If your catalog provides a fixed list of schemas, you + do not need to implement this method. + """ + + def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027 + """Remove a schema from this catalog. + + This method is optional. If your catalog provides a fixed list of schemas, you + do not need to implement this method. + + Args: + name: The name of the schema to remove. + cascade: If true, deregister the tables within the schema. + """ + + +class SchemaProvider(ABC): + def owner_name(self) -> str | None: + """Returns the owner of the schema. + + This is an optional method. The default return is None. + """ + return None + + @abstractmethod + def table_names(self) -> set[str]: + """Set of the names of all tables in this schema.""" + ... + + @abstractmethod + def table(self, name: str) -> Table | None: + """Retrieve a specific table from this schema.""" + ... + + def register_table(self, name: str, table: Table) -> None: # noqa: B027 + """Add a table from this schema. + + This method is optional. If your schema provides a fixed list of tables, you do + not need to implement this method. + """ + + def deregister_table(self, name, cascade: bool) -> None: # noqa: B027 + """Remove a table from this schema. + + This method is optional. If your schema provides a fixed list of tables, you do + not need to implement this method. + """ + + @abstractmethod + def table_exist(self, name: str) -> bool: + """Returns true if the table exists in this schema.""" + ... diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 045ddbc0..1ee1ae09 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations import datafusion as dfn import pyarrow as pa @@ -46,20 +47,16 @@ def test_basic(ctx, database): ) -class CustomTableProvider: - def __init__(self): - pass - - -def create_dataset() -> pa.dataset.Dataset: +def create_dataset() -> Table: batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) - return ds.dataset([batch]) + dataset = ds.dataset([batch]) + return Table.from_dataset(dataset) -class CustomSchemaProvider: +class CustomSchemaProvider(dfn.catalog.SchemaProvider): def __init__(self): self.tables = {"table1": create_dataset()} @@ -72,8 +69,14 @@ def register_table(self, name: str, table: Table): def deregister_table(self, name, cascade: bool = True): del self.tables[name] + def table(self, name: str) -> Table | None: + return self.tables[name] + + def table_exist(self, name: str) -> bool: + return name in self.tables + -class CustomCatalogProvider: +class CustomCatalogProvider(dfn.catalog.CatalogProvider): def __init__(self): self.schemas = {"my_schema": CustomSchemaProvider()} diff --git a/src/catalog.rs b/src/catalog.rs index d85e4069..aef0e33e 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -207,9 +207,14 @@ impl PySchema { let provider: ForeignTableProvider = provider.into(); Arc::new(provider) as Arc } else { - let py = table_provider.py(); - let provider = Dataset::new(&table_provider, py)?; - Arc::new(provider) as Arc + match table_provider.extract::() { + Ok(py_table) => py_table.table, + Err(_) => { + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc + } + } }; let _ = self @@ -238,6 +243,14 @@ impl PyTable { self.table.schema().to_pyarrow(py) } + #[staticmethod] + fn from_dataset(py: Python<'_>, dataset: &Bound<'_, PyAny>) -> PyResult { + let ds = Arc::new(Dataset::new(dataset, py).map_err(py_datafusion_err)?) + as Arc; + + Ok(Self::new(ds)) + } + /// Get the type of this table for metadata/catalog purposes. #[getter] fn kind(&self) -> &str { diff --git a/src/context.rs b/src/context.rs index d4879e85..1d3dfc71 100644 --- a/src/context.rs +++ b/src/context.rs @@ -633,17 +633,10 @@ impl PySessionContext { let provider: ForeignCatalogProvider = provider.into(); Arc::new(provider) as Arc } else { - println!("Provider has type {}", provider.get_type()); match provider.extract::() { - Ok(py_catalog) => { - println!("registering an existing PyCatalog"); - py_catalog.catalog - } - Err(_) => { - println!("registering a rust wrapped catalog provider"); - Arc::new(RustWrappedPyCatalogProvider::new(provider.into())) - as Arc - } + Ok(py_catalog) => py_catalog.catalog, + Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into())) + as Arc, } }; From 083f7c4e4277795dc19c1b68406f10cf8d819f8f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 13:49:48 -0400 Subject: [PATCH 12/18] Working through issues between custom catalog and build in schema --- python/datafusion/catalog.py | 15 +++++++++-- python/datafusion/context.py | 4 +-- python/tests/test_catalog.py | 48 ++++++++++++++++++++++++++++++++++++ src/catalog.rs | 14 +++++++++-- 4 files changed, 75 insertions(+), 6 deletions(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 9e3af3ac..d8755ac9 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -20,7 +20,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol import datafusion._internal as df_internal @@ -174,7 +174,9 @@ def schema(self, name: str) -> Schema | None: """Retrieve a specific schema from this catalog.""" ... - def register_schema(self, name: str, schema: Schema) -> None: # noqa: B027 + def register_schema( # noqa: B027 + self, name: str, schema: SchemaProviderExportable | SchemaProvider | Schema + ) -> None: """Add a schema to this catalog. This method is optional. If your catalog provides a fixed list of schemas, you @@ -229,3 +231,12 @@ def deregister_table(self, name, cascade: bool) -> None: # noqa: B027 def table_exist(self, name: str) -> bool: """Returns true if the table exists in this schema.""" ... + + +class SchemaProviderExportable(Protocol): + """Type hint for object that has __datafusion_schema_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.SchemaProvider.html + """ + + def __datafusion_schema_provider__(self) -> object: ... diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c652d4e8..bce51d64 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -29,7 +29,7 @@ except ImportError: from typing_extensions import deprecated # Python 3.12 -from datafusion.catalog import Catalog, Table +from datafusion.catalog import Catalog, CatalogProvider, Table from datafusion.dataframe import DataFrame from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream @@ -763,7 +763,7 @@ def catalog_names(self) -> set[str]: return self.ctx.catalog_names() def register_catalog_provider( - self, name: str, provider: CatalogProviderExportable | Catalog + self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog ) -> None: """Register a catalog provider.""" if isinstance(provider, Catalog): diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 1ee1ae09..1f9ecbfc 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -162,3 +162,51 @@ def test_python_table_provider(ctx: SessionContext): schema.deregister_table("table3") schema.register_table("table4", create_dataset()) assert schema.table_names() == {"table4"} + + +def test_in_end_to_end_python_providers(ctx: SessionContext): + """Test registering all python providers and running a query against them.""" + + all_catalog_names = [ + "datafusion", + "custom_catalog", + "in_mem_catalog", + ] + + all_schema_names = [ + "custom_schema", + "in_mem_schema", + ] + + ctx.register_catalog_provider(all_catalog_names[1], CustomCatalogProvider()) + ctx.register_catalog_provider( + all_catalog_names[2], dfn.catalog.Catalog.memory_catalog() + ) + + for catalog_name in all_catalog_names: + catalog = ctx.catalog(catalog_name) + + # Clean out previous schemas if they exist so we can start clean + for schema_name in catalog.schema_names(): + catalog.deregister_schema(schema_name, cascade=False) + + catalog.register_schema(all_schema_names[0], CustomSchemaProvider()) + catalog.register_schema(all_schema_names[1], dfn.catalog.Schema.memory_schema()) + + for schema_name in all_schema_names: + schema = catalog.schema(schema_name) + + for table_name in schema.table_names(): + schema.deregister_table(table_name) + + schema.register_table("test_table", create_dataset()) + + for catalog_name in all_catalog_names: + for schema_name in all_schema_names: + table_full_name = f"{catalog_name}.{schema_name}.test_table" + + batches = ctx.sql(f"select * from {table_full_name}").collect() + + assert len(batches) == 1 + assert batches[0].column(0) == pa.array([1, 2, 3]) + assert batches[0].column(1) == pa.array([4, 5, 6]) diff --git a/src/catalog.rs b/src/catalog.rs index aef0e33e..74c50a69 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -314,9 +314,19 @@ impl RustWrappedPySchemaProvider { Ok(Some(Arc::new(provider) as Arc)) } else { - let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + if let Ok(inner_table) = py_table.getattr("table") { + if let Ok(inner_table) = inner_table.extract::() { + return Ok(Some(inner_table.table)); + } + } - Ok(Some(Arc::new(ds) as Arc)) + match py_table.extract::() { + Ok(py_table) => Ok(Some(py_table.table)), + Err(_) => { + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + Ok(Some(Arc::new(ds) as Arc)) + } + } } }) } From a4b8b21e2220dadddf67c539bae789f49b86e8a4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 13:53:38 -0400 Subject: [PATCH 13/18] Check types on schema provider to return --- src/catalog.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/catalog.rs b/src/catalog.rs index 74c50a69..17d4ec3b 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -447,9 +447,19 @@ impl RustWrappedPyCatalogProvider { Ok(Some(Arc::new(provider) as Arc)) } else { - let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); + if let Ok(inner_schema) = py_schema.getattr("schema") { + if let Ok(inner_schema) = inner_schema.extract::() { + return Ok(Some(inner_schema.schema)); + } + } + match py_schema.extract::() { + Ok(inner_schema) => Ok(Some(inner_schema.schema)), + Err(_) => { + let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); - Ok(Some(Arc::new(py_schema) as Arc)) + Ok(Some(Arc::new(py_schema) as Arc)) + } + } } }) } From 55dd215b1dfef96fc09ea70d7fbe00a7f28b462c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 16:04:46 -0400 Subject: [PATCH 14/18] Add docstring --- python/datafusion/catalog.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index d8755ac9..3d3e532b 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -164,6 +164,8 @@ def kind(self) -> str: class CatalogProvider(ABC): + """Abstract class for defining a Python based Catalog Provider.""" + @abstractmethod def schema_names(self) -> set[str]: """Set of the names of all schemas in this catalog.""" @@ -196,6 +198,8 @@ def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027 class SchemaProvider(ABC): + """Abstract class for defining a Python based Schema Provider.""" + def owner_name(self) -> str | None: """Returns the owner of the schema. From b08121ab26d188e7709a2514f0413f31333b430c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 19 Jun 2025 16:39:09 -0400 Subject: [PATCH 15/18] Add documentation about how to use catalog and schema providers --- docs/source/user-guide/data-sources.rst | 56 +++++++++++++++++++++++++ python/datafusion/catalog.py | 2 + 2 files changed, 58 insertions(+) diff --git a/docs/source/user-guide/data-sources.rst b/docs/source/user-guide/data-sources.rst index ba5967c9..9c95d58e 100644 --- a/docs/source/user-guide/data-sources.rst +++ b/docs/source/user-guide/data-sources.rst @@ -185,3 +185,59 @@ the interface as describe in the :ref:`Custom Table Provider `_ is provided in the DataFusion repository. + +Catalog +======= + +A common technique for organizing tables is using a three level hierarchical approach. DataFusion +supports this form of organizing using the :py:class:`~datafusion.catalog.Catalog`, +:py:class:`~datafusion.catalog.Schema`, and :py:class:`~datafusion.catalog.Table`. By default, +a :py:class:`~datafusion.context.SessionContext` comes with a single Catalog and a single Schema +with the names ``datafusion`` and ``default``, respectively. + +The default implementation uses an in-memory approach to the catalog and schema. We have support +for adding additional in-memory catalogs and schemas. This can be done like in the following +example: + +.. code-block:: python + + from datafusion.catalog import Catalog, Schema + + my_catalog = Catalog.memory_catalog() + my_schema = Schema.memory_schema() + + my_catalog.register_schema("my_schema_name", my_schema) + + ctx.register_catalog("my_catalog_name", my_catalog) + +You could then register tables in ``my_schema`` and access them either through the DataFrame +API or via sql commands such as ``"SELECT * from my_catalog_name.my_schema_name.my_table"``. + +User Defined Catalog and Schema +------------------------------- + +If the in-memory catalogs are insufficient for your uses, there are two approaches you can take +to implementing a custom catalog and/or schema. In the below discussion, we describe how to +implement these for a Catalog, but the approach to implementing for a Schema is nearly +identical. + +DataFusion supports Catalogs written in either Rust or Python. If you write a Catalog in Rust, +you will need to export it as a Python library via PyO3. There is a complete example of a +catalog implemented this way in the +`examples folder `_ +of our repository. Writing catalog providers in Rust provides typically can lead to significant +performance improvements over the Python based approach. + +To implement a Catalog in Python, you will need to inherit from the abstract base class +:py:class:`~datafusion.catalog.CatalogProvider`. There are examples in the +`unit tests `_ of +implementing a basic Catalog in Python where we simply keep a dictionary of the +registered Schemas. + +One important note for developers is that when we have a Catalog defined in Python, we have +two different ways of accessing this Catalog. First, we register the catalog with a Rust +wrapper. This allows for any rust based code to call the Python functions as necessary. +Second, if the user access the Catalog via the Python API, we identify this and return back +the original Python object that implements the Catalog. This is an important distinction +for developers because we do *not* return a Python wrapper around the Rust wrapper of the +original Python object. diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 3d3e532b..536b3a79 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -35,7 +35,9 @@ __all__ = [ "Catalog", + "CatalogProvider", "Schema", + "SchemaProvider", "Table", ] From a6baba158fe75a26cad70c4362b5927f733df3ee Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Jul 2025 07:08:10 -0400 Subject: [PATCH 16/18] Re-add module to all after rebase --- python/datafusion/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index fd7f4fc0..e9d2dba7 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -92,6 +92,7 @@ "TableFunction", "WindowFrame", "WindowUDF", + "catalog", "col", "column", "common", From eb21c9bcdced9c60bc913248283702bea9c70ec5 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Jul 2025 07:24:06 -0400 Subject: [PATCH 17/18] Minor bugfix --- python/datafusion/dataframe.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 991e6875..61cb0943 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -760,19 +760,16 @@ def join_on( exprs = [expr.expr for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) - def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: - """Return a DataFrame with the explanation of its plan so far. + def explain(self, verbose: bool = False, analyze: bool = False) -> None: + """Print an explanation of the DataFrame's plan so far. If ``analyze`` is specified, runs the plan and reports metrics. Args: verbose: If ``True``, more details will be included. analyze: If ``Tru`e``, the plan will run and metrics reported. - - Returns: - DataFrame with the explanation of its plan. """ - return DataFrame(self.df.explain(verbose, analyze)) + self.df.explain(verbose, analyze) def logical_plan(self) -> LogicalPlan: """Return the unoptimized ``LogicalPlan``. From e0f352756f53a04de4072ce48bbfff56229ab935 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Jul 2025 07:33:58 -0400 Subject: [PATCH 18/18] Clippy updates from the new rust version --- src/common/data_type.rs | 120 +++++++++++++----------------- src/context.rs | 2 +- src/expr.rs | 15 ++-- src/expr/aggregate.rs | 2 +- src/expr/aggregate_expr.rs | 2 +- src/expr/alias.rs | 2 +- src/expr/analyze.rs | 2 +- src/expr/between.rs | 2 +- src/expr/column.rs | 2 +- src/expr/copy_to.rs | 4 +- src/expr/create_catalog.rs | 2 +- src/expr/create_catalog_schema.rs | 2 +- src/expr/create_external_table.rs | 2 +- src/expr/create_function.rs | 2 +- src/expr/create_index.rs | 2 +- src/expr/create_memory_table.rs | 2 +- src/expr/create_view.rs | 2 +- src/expr/describe_table.rs | 2 +- src/expr/distinct.rs | 5 +- src/expr/drop_catalog_schema.rs | 2 +- src/expr/drop_function.rs | 2 +- src/expr/drop_table.rs | 2 +- src/expr/drop_view.rs | 2 +- src/expr/empty_relation.rs | 2 +- src/expr/filter.rs | 2 +- src/expr/join.rs | 2 +- src/expr/like.rs | 6 +- src/expr/limit.rs | 2 +- src/expr/projection.rs | 2 +- src/expr/recursive_query.rs | 2 +- src/expr/repartition.rs | 2 +- src/expr/sort.rs | 2 +- src/expr/sort_expr.rs | 2 +- src/expr/subquery.rs | 2 +- src/expr/subquery_alias.rs | 2 +- src/expr/table_scan.rs | 2 +- src/expr/union.rs | 2 +- src/expr/unnest.rs | 2 +- src/expr/unnest_expr.rs | 2 +- src/expr/window.rs | 11 +-- src/physical_plan.rs | 3 +- src/sql/logical.rs | 3 +- src/utils.rs | 5 +- 43 files changed, 104 insertions(+), 136 deletions(-) diff --git a/src/common/data_type.rs b/src/common/data_type.rs index f5f8a6b0..5cf9d6e9 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -172,7 +172,7 @@ impl DataTypeMap { SqlType::DATE, )), DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Interval(interval_unit) => Ok(DataTypeMap::new( DataType::Interval(*interval_unit), @@ -189,7 +189,7 @@ impl DataTypeMap { SqlType::BINARY, )), DataType::FixedSizeBinary(_) => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", arrow_type)), + DataFusionError::NotImplemented(format!("{arrow_type:?}")), )), DataType::LargeBinary => Ok(DataTypeMap::new( DataType::LargeBinary, @@ -207,23 +207,22 @@ impl DataTypeMap { SqlType::VARCHAR, )), DataType::List(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - arrow_type + "{arrow_type:?}" )))), DataType::FixedSizeList(_, _) => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", arrow_type)), + DataFusionError::NotImplemented(format!("{arrow_type:?}")), )), DataType::LargeList(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Struct(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Union(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new( DataType::Decimal128(*precision, *scale), @@ -236,23 +235,22 @@ impl DataTypeMap { SqlType::DECIMAL, )), DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::RunEndEncoded(_, _) => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", arrow_type)), + DataFusionError::NotImplemented(format!("{arrow_type:?}")), )), DataType::BinaryView => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Utf8View => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - arrow_type + "{arrow_type:?}" )))), DataType::ListView(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::LargeListView(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), } } @@ -379,8 +377,7 @@ impl DataTypeMap { "double" => Ok(DataType::Float64), "byte_array" => Ok(DataType::Utf8), _ => Err(PyValueError::new_err(format!( - "Unable to determine Arrow Data Type from Parquet String type: {:?}", - parquet_str_type + "Unable to determine Arrow Data Type from Parquet String type: {parquet_str_type:?}" ))), }; DataTypeMap::map_from_arrow_type(&arrow_dtype?) @@ -404,12 +401,10 @@ impl DataTypeMap { pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult { match sql_type { SqlType::ANY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::ARRAY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::BIGINT => Ok(DataTypeMap::new( DataType::Int64, @@ -432,11 +427,10 @@ impl DataTypeMap { SqlType::CHAR, )), SqlType::COLUMN_LIST => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::CURSOR => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::DATE => Ok(DataTypeMap::new( DataType::Date64, @@ -449,8 +443,7 @@ impl DataTypeMap { SqlType::DECIMAL, )), SqlType::DISTINCT => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::DOUBLE => Ok(DataTypeMap::new( DataType::Decimal256(1, 1), @@ -458,7 +451,7 @@ impl DataTypeMap { SqlType::DOUBLE, )), SqlType::DYNAMIC_STAR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::FLOAT => Ok(DataTypeMap::new( DataType::Decimal128(1, 1), @@ -466,8 +459,7 @@ impl DataTypeMap { SqlType::FLOAT, )), SqlType::GEOMETRY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::INTEGER => Ok(DataTypeMap::new( DataType::Int8, @@ -475,55 +467,52 @@ impl DataTypeMap { SqlType::INTEGER, )), SqlType::INTERVAL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::INTERVAL_DAY => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_DAY_HOUR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_DAY_MINUTE => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_DAY_SECOND => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_HOUR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_HOUR_MINUTE => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_HOUR_SECOND => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_MINUTE => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_MINUTE_SECOND => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_MONTH => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_SECOND => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_YEAR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_YEAR_MONTH => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::MAP => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::MULTISET => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::NULL => Ok(DataTypeMap::new( DataType::Null, @@ -531,20 +520,16 @@ impl DataTypeMap { SqlType::NULL, )), SqlType::OTHER => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::REAL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::ROW => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::SARG => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::SMALLINT => Ok(DataTypeMap::new( DataType::Int16, @@ -552,25 +537,22 @@ impl DataTypeMap { SqlType::SMALLINT, )), SqlType::STRUCTURED => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::SYMBOL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::TIME => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::TIME_WITH_LOCAL_TIME_ZONE => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::TIMESTAMP => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::TINYINT => Ok(DataTypeMap::new( DataType::Int8, @@ -578,8 +560,7 @@ impl DataTypeMap { SqlType::TINYINT, )), SqlType::UNKNOWN => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::VARBINARY => Ok(DataTypeMap::new( DataType::LargeBinary, @@ -682,8 +663,7 @@ impl PyDataType { "datetime64" => Ok(DataType::Date64), "object" => Ok(DataType::Utf8), _ => Err(PyValueError::new_err(format!( - "Unable to determine Arrow Data Type from Arrow String type: {:?}", - arrow_str_type + "Unable to determine Arrow Data Type from Arrow String type: {arrow_str_type:?}" ))), }; Ok(PyDataType { diff --git a/src/context.rs b/src/context.rs index 1d3dfc71..36133a33 100644 --- a/src/context.rs +++ b/src/context.rs @@ -368,7 +368,7 @@ impl PySessionContext { } else { &upstream_host }; - let url_string = format!("{}{}", scheme, derived_host); + let url_string = format!("{scheme}{derived_host}"); let url = Url::parse(&url_string).unwrap(); self.ctx.runtime_env().register_object_store(&url, store); Ok(()) diff --git a/src/expr.rs b/src/expr.rs index 6b1d01d6..f1e00236 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -171,12 +171,10 @@ impl PyExpr { Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_bound_py_any(py)?), Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_bound_py_any(py)?), Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!( - "Converting Expr::ScalarFunction to a Python object is not implemented: {:?}", - value + "Converting Expr::ScalarFunction to a Python object is not implemented: {value:?}" ))), Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!( - "Converting Expr::WindowFunction to a Python object is not implemented: {:?}", - value + "Converting Expr::WindowFunction to a Python object is not implemented: {value:?}" ))), Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_bound_py_any(py)?), Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_bound_py_any(py)?), @@ -188,8 +186,7 @@ impl PyExpr { } #[allow(deprecated)] Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!( - "Converting Expr::Wildcard to a Python object is not implemented : {:?} {:?}", - qualifier, options + "Converting Expr::Wildcard to a Python object is not implemented : {qualifier:?} {options:?}" ))), Expr::GroupingSet(value) => { Ok(grouping_set::PyGroupingSet::from(value.clone()).into_bound_py_any(py)?) @@ -198,8 +195,7 @@ impl PyExpr { Ok(placeholder::PyPlaceholder::from(value.clone()).into_bound_py_any(py)?) } Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!( - "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}", - data_type, column + "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {data_type:?} - {column:?}" ))), Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_bound_py_any(py)?), } @@ -755,8 +751,7 @@ impl PyExpr { Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type), Expr::Literal(scalar_value, _) => DataTypeMap::map_from_scalar_value(scalar_value), _ => Err(py_type_err(format!( - "Non Expr::Literal encountered in types: {:?}", - expr + "Non Expr::Literal encountered in types: {expr:?}" ))), } } diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index a99d83d2..fd439327 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -116,7 +116,7 @@ impl PyAggregate { } fn __repr__(&self) -> PyResult { - Ok(format!("Aggregate({})", self)) + Ok(format!("Aggregate({self})")) } } diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs index c09f116e..7c5d3d31 100644 --- a/src/expr/aggregate_expr.rs +++ b/src/expr/aggregate_expr.rs @@ -75,6 +75,6 @@ impl PyAggregateFunction { /// Get a String representation of this column fn __repr__(&self) -> String { - format!("{}", self) + format!("{self}") } } diff --git a/src/expr/alias.rs b/src/expr/alias.rs index e8e03cfa..40746f20 100644 --- a/src/expr/alias.rs +++ b/src/expr/alias.rs @@ -64,6 +64,6 @@ impl PyAlias { /// Get a String representation of this column fn __repr__(&self) -> String { - format!("{}", self) + format!("{self}") } } diff --git a/src/expr/analyze.rs b/src/expr/analyze.rs index 62f93cd2..e8081e95 100644 --- a/src/expr/analyze.rs +++ b/src/expr/analyze.rs @@ -69,7 +69,7 @@ impl PyAnalyze { } fn __repr__(&self) -> PyResult { - Ok(format!("Analyze({})", self)) + Ok(format!("Analyze({self})")) } } diff --git a/src/expr/between.rs b/src/expr/between.rs index a2cac144..817f1baa 100644 --- a/src/expr/between.rs +++ b/src/expr/between.rs @@ -71,6 +71,6 @@ impl PyBetween { } fn __repr__(&self) -> String { - format!("{}", self) + format!("{self}") } } diff --git a/src/expr/column.rs b/src/expr/column.rs index 365dbc0d..50f316f1 100644 --- a/src/expr/column.rs +++ b/src/expr/column.rs @@ -45,7 +45,7 @@ impl PyColumn { /// Get the column relation fn relation(&self) -> Option { - self.col.relation.as_ref().map(|r| format!("{}", r)) + self.col.relation.as_ref().map(|r| format!("{r}")) } /// Get the fully-qualified column name diff --git a/src/expr/copy_to.rs b/src/expr/copy_to.rs index ebfcb8eb..473dabfe 100644 --- a/src/expr/copy_to.rs +++ b/src/expr/copy_to.rs @@ -106,7 +106,7 @@ impl PyCopyTo { } fn __repr__(&self) -> PyResult { - Ok(format!("CopyTo({})", self)) + Ok(format!("CopyTo({self})")) } fn __name__(&self) -> PyResult { @@ -129,7 +129,7 @@ impl Display for PyFileType { #[pymethods] impl PyFileType { fn __repr__(&self) -> PyResult { - Ok(format!("FileType({})", self)) + Ok(format!("FileType({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_catalog.rs b/src/expr/create_catalog.rs index f4ea0f51..d2d2ee8f 100644 --- a/src/expr/create_catalog.rs +++ b/src/expr/create_catalog.rs @@ -81,7 +81,7 @@ impl PyCreateCatalog { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateCatalog({})", self)) + Ok(format!("CreateCatalog({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_catalog_schema.rs b/src/expr/create_catalog_schema.rs index 85f447e1..e794962f 100644 --- a/src/expr/create_catalog_schema.rs +++ b/src/expr/create_catalog_schema.rs @@ -81,7 +81,7 @@ impl PyCreateCatalogSchema { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateCatalogSchema({})", self)) + Ok(format!("CreateCatalogSchema({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_external_table.rs b/src/expr/create_external_table.rs index 01ce7d0c..3e35af00 100644 --- a/src/expr/create_external_table.rs +++ b/src/expr/create_external_table.rs @@ -164,7 +164,7 @@ impl PyCreateExternalTable { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateExternalTable({})", self)) + Ok(format!("CreateExternalTable({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_function.rs b/src/expr/create_function.rs index 6f3c3f0f..c02ceebb 100644 --- a/src/expr/create_function.rs +++ b/src/expr/create_function.rs @@ -163,7 +163,7 @@ impl PyCreateFunction { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateFunction({})", self)) + Ok(format!("CreateFunction({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_index.rs b/src/expr/create_index.rs index 13dadbc3..0f4b5011 100644 --- a/src/expr/create_index.rs +++ b/src/expr/create_index.rs @@ -110,7 +110,7 @@ impl PyCreateIndex { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateIndex({})", self)) + Ok(format!("CreateIndex({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_memory_table.rs b/src/expr/create_memory_table.rs index 8872b2d4..37f4d342 100644 --- a/src/expr/create_memory_table.rs +++ b/src/expr/create_memory_table.rs @@ -78,7 +78,7 @@ impl PyCreateMemoryTable { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateMemoryTable({})", self)) + Ok(format!("CreateMemoryTable({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/create_view.rs b/src/expr/create_view.rs index 87bb7687..718e404d 100644 --- a/src/expr/create_view.rs +++ b/src/expr/create_view.rs @@ -75,7 +75,7 @@ impl PyCreateView { } fn __repr__(&self) -> PyResult { - Ok(format!("CreateView({})", self)) + Ok(format!("CreateView({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/describe_table.rs b/src/expr/describe_table.rs index 5658a13f..6c48f3c7 100644 --- a/src/expr/describe_table.rs +++ b/src/expr/describe_table.rs @@ -61,7 +61,7 @@ impl PyDescribeTable { } fn __repr__(&self) -> PyResult { - Ok(format!("DescribeTable({})", self)) + Ok(format!("DescribeTable({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/distinct.rs b/src/expr/distinct.rs index b62b776f..889e7099 100644 --- a/src/expr/distinct.rs +++ b/src/expr/distinct.rs @@ -48,8 +48,7 @@ impl Display for PyDistinct { Distinct::All(input) => write!( f, "Distinct ALL - \nInput: {:?}", - input, + \nInput: {input:?}", ), Distinct::On(distinct_on) => { write!( @@ -71,7 +70,7 @@ impl PyDistinct { } fn __repr__(&self) -> PyResult { - Ok(format!("Distinct({})", self)) + Ok(format!("Distinct({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/drop_catalog_schema.rs b/src/expr/drop_catalog_schema.rs index b7420a99..b4a4c521 100644 --- a/src/expr/drop_catalog_schema.rs +++ b/src/expr/drop_catalog_schema.rs @@ -101,7 +101,7 @@ impl PyDropCatalogSchema { } fn __repr__(&self) -> PyResult { - Ok(format!("DropCatalogSchema({})", self)) + Ok(format!("DropCatalogSchema({self})")) } } diff --git a/src/expr/drop_function.rs b/src/expr/drop_function.rs index 9fbd78fd..fca9eb94 100644 --- a/src/expr/drop_function.rs +++ b/src/expr/drop_function.rs @@ -76,7 +76,7 @@ impl PyDropFunction { } fn __repr__(&self) -> PyResult { - Ok(format!("DropFunction({})", self)) + Ok(format!("DropFunction({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/drop_table.rs b/src/expr/drop_table.rs index 96983c1c..3f442539 100644 --- a/src/expr/drop_table.rs +++ b/src/expr/drop_table.rs @@ -70,7 +70,7 @@ impl PyDropTable { } fn __repr__(&self) -> PyResult { - Ok(format!("DropTable({})", self)) + Ok(format!("DropTable({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/drop_view.rs b/src/expr/drop_view.rs index 1d1ab1e5..6196c8bb 100644 --- a/src/expr/drop_view.rs +++ b/src/expr/drop_view.rs @@ -83,7 +83,7 @@ impl PyDropView { } fn __repr__(&self) -> PyResult { - Ok(format!("DropView({})", self)) + Ok(format!("DropView({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/empty_relation.rs b/src/expr/empty_relation.rs index a1534ac1..75821342 100644 --- a/src/expr/empty_relation.rs +++ b/src/expr/empty_relation.rs @@ -65,7 +65,7 @@ impl PyEmptyRelation { /// Get a String representation of this column fn __repr__(&self) -> String { - format!("{}", self) + format!("{self}") } fn __name__(&self) -> PyResult { diff --git a/src/expr/filter.rs b/src/expr/filter.rs index 9bdb667c..4fcb600c 100644 --- a/src/expr/filter.rs +++ b/src/expr/filter.rs @@ -72,7 +72,7 @@ impl PyFilter { } fn __repr__(&self) -> String { - format!("Filter({})", self) + format!("Filter({self})") } } diff --git a/src/expr/join.rs b/src/expr/join.rs index 76ec532e..b8d1d9da 100644 --- a/src/expr/join.rs +++ b/src/expr/join.rs @@ -177,7 +177,7 @@ impl PyJoin { } fn __repr__(&self) -> PyResult { - Ok(format!("Join({})", self)) + Ok(format!("Join({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/like.rs b/src/expr/like.rs index 2e1f060b..f180f5d4 100644 --- a/src/expr/like.rs +++ b/src/expr/like.rs @@ -75,7 +75,7 @@ impl PyLike { } fn __repr__(&self) -> String { - format!("Like({})", self) + format!("Like({self})") } } @@ -133,7 +133,7 @@ impl PyILike { } fn __repr__(&self) -> String { - format!("Like({})", self) + format!("Like({self})") } } @@ -191,6 +191,6 @@ impl PySimilarTo { } fn __repr__(&self) -> String { - format!("Like({})", self) + format!("Like({self})") } } diff --git a/src/expr/limit.rs b/src/expr/limit.rs index c2a33ff8..92552814 100644 --- a/src/expr/limit.rs +++ b/src/expr/limit.rs @@ -81,7 +81,7 @@ impl PyLimit { } fn __repr__(&self) -> PyResult { - Ok(format!("Limit({})", self)) + Ok(format!("Limit({self})")) } } diff --git a/src/expr/projection.rs b/src/expr/projection.rs index dc7e5e3c..b5a9ef34 100644 --- a/src/expr/projection.rs +++ b/src/expr/projection.rs @@ -85,7 +85,7 @@ impl PyProjection { } fn __repr__(&self) -> PyResult { - Ok(format!("Projection({})", self)) + Ok(format!("Projection({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/recursive_query.rs b/src/expr/recursive_query.rs index 65181f7d..2517b741 100644 --- a/src/expr/recursive_query.rs +++ b/src/expr/recursive_query.rs @@ -89,7 +89,7 @@ impl PyRecursiveQuery { } fn __repr__(&self) -> PyResult { - Ok(format!("RecursiveQuery({})", self)) + Ok(format!("RecursiveQuery({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/repartition.rs b/src/expr/repartition.rs index 3e782d6a..48b5e704 100644 --- a/src/expr/repartition.rs +++ b/src/expr/repartition.rs @@ -108,7 +108,7 @@ impl PyRepartition { } fn __repr__(&self) -> PyResult { - Ok(format!("Repartition({})", self)) + Ok(format!("Repartition({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/sort.rs b/src/expr/sort.rs index ed494759..79a8aee5 100644 --- a/src/expr/sort.rs +++ b/src/expr/sort.rs @@ -87,7 +87,7 @@ impl PySort { } fn __repr__(&self) -> PyResult { - Ok(format!("Sort({})", self)) + Ok(format!("Sort({self})")) } } diff --git a/src/expr/sort_expr.rs b/src/expr/sort_expr.rs index 12f74e4d..79e35d97 100644 --- a/src/expr/sort_expr.rs +++ b/src/expr/sort_expr.rs @@ -85,6 +85,6 @@ impl PySortExpr { } fn __repr__(&self) -> String { - format!("{}", self) + format!("{self}") } } diff --git a/src/expr/subquery.rs b/src/expr/subquery.rs index 5ebfe692..77f56f9a 100644 --- a/src/expr/subquery.rs +++ b/src/expr/subquery.rs @@ -62,7 +62,7 @@ impl PySubquery { } fn __repr__(&self) -> PyResult { - Ok(format!("Subquery({})", self)) + Ok(format!("Subquery({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/subquery_alias.rs b/src/expr/subquery_alias.rs index 267a4d48..3302e7f2 100644 --- a/src/expr/subquery_alias.rs +++ b/src/expr/subquery_alias.rs @@ -72,7 +72,7 @@ impl PySubqueryAlias { } fn __repr__(&self) -> PyResult { - Ok(format!("SubqueryAlias({})", self)) + Ok(format!("SubqueryAlias({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/table_scan.rs b/src/expr/table_scan.rs index 6a0d53f0..32996468 100644 --- a/src/expr/table_scan.rs +++ b/src/expr/table_scan.rs @@ -136,7 +136,7 @@ impl PyTableScan { } fn __repr__(&self) -> PyResult { - Ok(format!("TableScan({})", self)) + Ok(format!("TableScan({self})")) } } diff --git a/src/expr/union.rs b/src/expr/union.rs index 5a08ccc1..e0b22139 100644 --- a/src/expr/union.rs +++ b/src/expr/union.rs @@ -66,7 +66,7 @@ impl PyUnion { } fn __repr__(&self) -> PyResult { - Ok(format!("Union({})", self)) + Ok(format!("Union({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/unnest.rs b/src/expr/unnest.rs index 8e70e099..c8833347 100644 --- a/src/expr/unnest.rs +++ b/src/expr/unnest.rs @@ -66,7 +66,7 @@ impl PyUnnest { } fn __repr__(&self) -> PyResult { - Ok(format!("Unnest({})", self)) + Ok(format!("Unnest({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/unnest_expr.rs b/src/expr/unnest_expr.rs index 2234d24b..634186ed 100644 --- a/src/expr/unnest_expr.rs +++ b/src/expr/unnest_expr.rs @@ -58,7 +58,7 @@ impl PyUnnestExpr { } fn __repr__(&self) -> PyResult { - Ok(format!("UnnestExpr({})", self)) + Ok(format!("UnnestExpr({self})")) } fn __name__(&self) -> PyResult { diff --git a/src/expr/window.rs b/src/expr/window.rs index 052d9eeb..a408731c 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -185,8 +185,7 @@ impl PyWindowFrame { "groups" => WindowFrameUnits::Groups, _ => { return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - units, + "{units:?}", )))); } }; @@ -197,8 +196,7 @@ impl PyWindowFrame { WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), WindowFrameUnits::Groups => { return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - units, + "{units:?}", )))); } }, @@ -210,8 +208,7 @@ impl PyWindowFrame { WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)), WindowFrameUnits::Groups => { return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - units, + "{units:?}", )))); } }, @@ -236,7 +233,7 @@ impl PyWindowFrame { /// Get a String representation of this window frame fn __repr__(&self) -> String { - format!("{}", self) + format!("{self}") } } diff --git a/src/physical_plan.rs b/src/physical_plan.rs index f0be45c6..49db643e 100644 --- a/src/physical_plan.rs +++ b/src/physical_plan.rs @@ -78,8 +78,7 @@ impl PyExecutionPlan { let proto_plan = datafusion_proto::protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| { PyRuntimeError::new_err(format!( - "Unable to decode logical node from serialized bytes: {}", - e + "Unable to decode logical node from serialized bytes: {e}" )) })?; diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 198d68bd..97d32047 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -201,8 +201,7 @@ impl PyLogicalPlan { let proto_plan = datafusion_proto::protobuf::LogicalPlanNode::decode(bytes).map_err(|e| { PyRuntimeError::new_err(format!( - "Unable to decode logical node from serialized bytes: {}", - e + "Unable to decode logical node from serialized bytes: {e}" )) })?; diff --git a/src/utils.rs b/src/utils.rs index f4e121fd..3b30de5d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -109,8 +109,7 @@ pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyRe let capsule_name = capsule_name.unwrap().to_str()?; if capsule_name != name { return Err(PyValueError::new_err(format!( - "Expected name '{}' in PyCapsule, instead got '{}'", - name, capsule_name + "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'" ))); } @@ -127,7 +126,7 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult