diff --git a/Cargo.lock b/Cargo.lock index 112167cb4..a3e9336cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,12 @@ dependencies = [ "zstd", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.9" @@ -1503,6 +1509,7 @@ dependencies = [ "datafusion-proto", "datafusion-substrait", "futures", + "log", "mimalloc", "object_store", "prost", @@ -1510,6 +1517,7 @@ dependencies = [ "pyo3", "pyo3-async-runtimes", "pyo3-build-config", + "pyo3-log", "tokio", "url", "uuid", @@ -2953,6 +2961,17 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.24.2" diff --git a/Cargo.toml b/Cargo.toml index 4135e64e2..1f7895a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} +pyo3-log = "0.12.4" arrow = { version = "55.1.0", features = ["pyarrow"] } datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] } datafusion-substrait = { version = "48.0.0", optional = true } @@ -49,6 +50,7 @@ async-trait = "0.1.88" futures = "0.3" object_store = { version = "0.12.1", features = ["aws", "gcp", "azure", "http"] } url = "2" +log = "0.4.27" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/docs/source/user-guide/data-sources.rst b/docs/source/user-guide/data-sources.rst index ba5967c97..9c95d58e0 100644 --- a/docs/source/user-guide/data-sources.rst +++ b/docs/source/user-guide/data-sources.rst @@ -185,3 +185,59 @@ the interface as described in the :ref:`Custom Table Provider `_ is provided in the DataFusion repository. + +Catalog +======= + +A common technique for organizing tables is a three-level hierarchy. DataFusion +supports this form of organization using the :py:class:`~datafusion.catalog.Catalog`, +:py:class:`~datafusion.catalog.Schema`, and :py:class:`~datafusion.catalog.Table` classes. By default, +a :py:class:`~datafusion.context.SessionContext` comes with a single Catalog and a single Schema, +named ``datafusion`` and ``public``, respectively. + +The default catalog and schema implementations are in-memory, and you can register additional +in-memory catalogs and schemas, as in the following example: + +.. code-block:: python + + from datafusion.catalog import Catalog, Schema + + my_catalog = Catalog.memory_catalog() + my_schema = Schema.memory_schema() + + my_catalog.register_schema("my_schema_name", my_schema) + + ctx.register_catalog_provider("my_catalog_name", my_catalog) + +You could then register tables in ``my_schema`` and access them either through the DataFrame +API or via SQL commands such as ``"SELECT * FROM my_catalog_name.my_schema_name.my_table"``.
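+ +As a short sketch of that last step (this assumes ``ctx`` is the :py:class:`~datafusion.context.SessionContext` from the example above and ``my_dataset`` is a PyArrow dataset you have already created), registering and querying a table could look like: + +.. code-block:: python + + from datafusion.catalog import Table + + # Wrap the PyArrow dataset in a Table and register it with the schema + my_schema.register_table("my_table", Table.from_dataset(my_dataset)) + + # The table is now reachable through its full three-part name + df = ctx.sql("SELECT * FROM my_catalog_name.my_schema_name.my_table") +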
+ +User Defined Catalog and Schema +------------------------------- + +If the in-memory catalogs are insufficient for your needs, there are two approaches you can take +to implementing a custom catalog and/or schema. The discussion below describes how to implement +a custom Catalog, but the approach for a Schema is nearly +identical. + +DataFusion supports Catalogs written in either Rust or Python. If you write a Catalog in Rust, +you will need to export it as a Python library via PyO3. There is a complete example of a +catalog implemented this way in the +`examples folder `_ +of our repository. Writing catalog providers in Rust typically leads to significant +performance improvements over the Python-based approach. + +To implement a Catalog in Python, you will need to inherit from the abstract base class +:py:class:`~datafusion.catalog.CatalogProvider`. The `unit tests `_ contain an example of +implementing a basic Catalog in Python that simply keeps a dictionary of the +registered Schemas. + +One important note for developers: when a Catalog is defined in Python, there are +two different ways of accessing it. First, we register the catalog with a Rust +wrapper, which allows any Rust-based code to call the Python functions as necessary. +Second, if the user accesses the Catalog via the Python API, we detect this and return +the original Python object that implements the Catalog. This is an important distinction +for developers because we do *not* return a Python wrapper around the Rust wrapper of the +original Python object. diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index 075ebd5a1..e5a1ca8d1 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -1448,6 +1448,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-schema", + "async-trait", "datafusion", "datafusion-ffi", "pyo3", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 0e17567b9..319163554 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } +async-trait = "0.1.88" [build-dependencies] pyo3-build-config = "0.23" diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py new file mode 100644 index 000000000..72aadf64c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext +from datafusion_ffi_example import MyCatalogProvider + + +def test_catalog_provider(): + ctx = SessionContext() + + my_catalog_name = "my_catalog" + expected_schema_name = "my_schema" + expected_table_name = "my_table" + expected_table_columns = ["units", "price"] + + catalog_provider = MyCatalogProvider() + ctx.register_catalog_provider(my_catalog_name, catalog_provider) + my_catalog = ctx.catalog(my_catalog_name) + + my_catalog_schemas = my_catalog.names() + assert expected_schema_name in my_catalog_schemas + my_database = my_catalog.database(expected_schema_name) + assert expected_table_name in my_database.names() + my_table = my_database.table(expected_table_name) + assert expected_table_columns == my_table.schema.names + + result = ctx.table( + f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}" + ).collect() + assert len(result) == 2 + + col0_result = [r.column(0) for r in result] + col1_result = [r.column(1) for r in result] + expected_col0 = [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array([5, 7], type=pa.int32()), + ] + expected_col1 = [ + pa.array([1, 2, 5], type=pa.float64()), + pa.array([1.5, 2.5], type=pa.float64()), + ] + assert col0_result == expected_col0 + assert col1_result == expected_col1 diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs new file mode 100644 index 000000000..54e61cf3e --- /dev/null +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::{any::Any, fmt::Debug, sync::Arc}; + +use arrow::datatypes::Schema; +use async_trait::async_trait; +use datafusion::{ + catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider, + }, + common::exec_err, + datasource::MemTable, + error::{DataFusionError, Result}, +}; +use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use pyo3::types::PyCapsule; + +pub fn my_table() -> Arc<dyn TableProvider> { + use arrow::datatypes::{DataType, Field}; + use datafusion::common::record_batch; + + let schema = Arc::new(Schema::new(vec![ + Field::new("units", DataType::Int32, true), + Field::new("price", DataType::Float64, true), + ])); + + let partitions = vec![ + record_batch!( + ("units", Int32, vec![10, 20, 30]), + ("price", Float64, vec![1.0, 2.0, 5.0]) + ) + .unwrap(), + record_batch!( + ("units", Int32, vec![5, 7]), + ("price", Float64, vec![1.5, 2.5]) + ) + .unwrap(), + ]; + + Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap()) +} + +#[derive(Debug)] +pub struct FixedSchemaProvider { + inner: MemorySchemaProvider, +} + +impl Default for FixedSchemaProvider { + fn default() -> Self { + let inner = MemorySchemaProvider::new(); + + let table = my_table(); + + let _ = inner.register_table("my_table".to_string(), table).unwrap(); + + Self { inner } + } +} + +#[async_trait] +impl SchemaProvider for FixedSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec<String> { + self.inner.table_names() + } + + async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> { + self.inner.table(name).await + } + + fn register_table( + &self, + name: String, + table: Arc<dyn TableProvider>, + ) -> Result<Option<Arc<dyn TableProvider>>> { + self.inner.register_table(name, table) + } + + fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> { + self.inner.deregister_table(name) + } + + fn table_exist(&self, name: &str) -> bool { + self.inner.table_exist(name) + } +} + +/// This catalog provider is intended only for unit tests. It is prepopulated +/// with a single schema named `my_schema` that contains one table.
+#[pyclass( + name = "MyCatalogProvider", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug)] +pub(crate) struct MyCatalogProvider { + inner: MemoryCatalogProvider, +} + +impl Default for MyCatalogProvider { + fn default() -> Self { + let inner = MemoryCatalogProvider::new(); + + let schema_name: &str = "my_schema"; + let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default())); + + Self { inner } + } +} + +impl CatalogProvider for MyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec<String> { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> { + self.inner.schema(name) + } + + fn register_schema( + &self, + name: &str, + schema: Arc<dyn SchemaProvider>, + ) -> Result<Option<Arc<dyn SchemaProvider>>> { + self.inner.register_schema(name, schema) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result<Option<Arc<dyn SchemaProvider>>> { + self.inner.deregister_schema(name, cascade) + } +} + +#[pymethods] +impl MyCatalogProvider { + #[new] + pub fn new() -> Self { + Self { + inner: Default::default(), + } + } + + pub fn __datafusion_catalog_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult<Bound<'py, PyCapsule>> { + let name = cr"datafusion_catalog_provider".into(); + let catalog_provider = + FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + + PyCapsule::new(py, catalog_provider, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index ae08c3b65..3a4cf2247 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::catalog_provider::MyCatalogProvider; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; use pyo3::prelude::*; +pub(crate) mod catalog_provider; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -26,5 +28,6 @@ pub(crate) mod table_provider; fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::<MyTableProvider>()?; m.add_class::<MyTableFunction>()?; + m.add_class::<MyCatalogProvider>()?; Ok(()) } diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index fd7f4fc06..e9d2dba75 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -92,6 +92,7 @@ "TableFunction", "WindowFrame", "WindowUDF", + "catalog", "col", "column", "common", diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 67ab3ead2..536b3a790 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -19,18 +19,33 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Protocol import datafusion._internal as df_internal if TYPE_CHECKING: import pyarrow as pa +try: + from warnings import deprecated # Python 3.13+ +except ImportError: + from typing_extensions import deprecated # Python 3.12 + + +__all__ = [ + "Catalog", + "CatalogProvider", + "Schema", + "SchemaProvider", + "Table", +] + class Catalog: """DataFusion data catalog.""" - def __init__(self, catalog: df_internal.Catalog) -> None: + def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None: """This constructor is not typically called by the end user.""" self.catalog = catalog @@ -38,39 +53,95 @@ def __repr__(self) -> str: """Print a string representation of the catalog.""" return self.catalog.__repr__() - def names(self) -> list[str]: - """Returns the list of
databases in this catalog.""" - return self.catalog.names() + def names(self) -> set[str]: + """This is an alias for ``schema_names``.""" + return self.schema_names() + + def schema_names(self) -> set[str]: + """Returns the set of schema names in this catalog.""" + return self.catalog.schema_names() + + @staticmethod + def memory_catalog() -> Catalog: + """Create an in-memory catalog provider.""" + catalog = df_internal.catalog.RawCatalog.memory_catalog() + return Catalog(catalog) - def database(self, name: str = "public") -> Database: - """Returns the database with the given ``name`` from this catalog.""" - return Database(self.catalog.database(name)) + def schema(self, name: str = "public") -> Schema: + """Returns the schema with the given ``name`` from this catalog.""" + schema = self.catalog.schema(name) + + return ( + Schema(schema) + if isinstance(schema, df_internal.catalog.RawSchema) + else schema + ) + + @deprecated("Use `schema` instead.") + def database(self, name: str = "public") -> Schema: + """Returns the database with the given ``name`` from this catalog.""" + return self.schema(name) + + def register_schema( + self, name: str, schema: SchemaProviderExportable | SchemaProvider | Schema + ) -> Schema | None: + """Register a schema with this catalog.""" + if isinstance(schema, Schema): + return self.catalog.register_schema(name, schema._raw_schema) + return self.catalog.register_schema(name, schema) + + def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None: + """Deregister a schema from this catalog.""" + return self.catalog.deregister_schema(name, cascade) -class Database: - """DataFusion Database.""" +class Schema: + """DataFusion Schema.""" - def __init__(self, db: df_internal.Database) -> None: + def __init__(self, schema: df_internal.catalog.RawSchema) -> None: """This constructor is not typically called by the end user.""" - self.db = db + self._raw_schema = schema def __repr__(self) -> str: - """Print a string representation of the database.""" - return self.db.__repr__() + """Print a string representation of the schema.""" + return self._raw_schema.__repr__() + + @staticmethod + def memory_schema() -> Schema: + """Create an in-memory schema provider.""" + schema = df_internal.catalog.RawSchema.memory_schema() + return Schema(schema) def names(self) -> set[str]: - """Returns the list of all tables in this database.""" - return self.db.names() + """This is an alias for ``table_names``.""" + return self.table_names() + + def table_names(self) -> set[str]: + """Returns the set of all table names in this schema.""" + return self._raw_schema.table_names def table(self, name: str) -> Table: - """Return the table with the given ``name`` from this database.""" - return Table(self.db.table(name)) + """Return the table with the given ``name`` from this schema.""" + return Table(self._raw_schema.table(name)) + + def register_table(self, name: str, table) -> None: + """Register a table provider in this schema.""" + if isinstance(table, Table): + return self._raw_schema.register_table(name, table.table) + return self._raw_schema.register_table(name, table) + + def deregister_table(self, name: str) -> None: + """Deregister a table provider from this schema.""" + return self._raw_schema.deregister_table(name) + + +@deprecated("Use `Schema` instead.") +class Database(Schema): + """See ``Schema``.""" class Table: """DataFusion table.""" - def __init__(self, table: df_internal.Table) -> None: + def __init__(self, table: df_internal.catalog.RawTable) -> None: """This constructor is not typically called by the end user.""" self.table = table @@ -78,6 +149,11 @@ def __repr__(self) -> str: """Print a string
representation of the table.""" return self.table.__repr__() + @staticmethod + def from_dataset(dataset: pa.dataset.Dataset) -> Table: + """Turn a pyarrow Dataset into a Table.""" + return Table(df_internal.catalog.RawTable.from_dataset(dataset)) + @property def schema(self) -> pa.Schema: """Returns the schema associated with this table.""" @@ -87,3 +163,86 @@ def kind(self) -> str: """Returns the kind of table.""" return self.table.kind + + +class CatalogProvider(ABC): + """Abstract class for defining a Python-based Catalog Provider.""" + + @abstractmethod + def schema_names(self) -> set[str]: + """Set of the names of all schemas in this catalog.""" + ... + + @abstractmethod + def schema(self, name: str) -> Schema | None: + """Retrieve a specific schema from this catalog.""" + ... + + def register_schema(  # noqa: B027 + self, name: str, schema: SchemaProviderExportable | SchemaProvider | Schema + ) -> None: + """Add a schema to this catalog. + + This method is optional. If your catalog provides a fixed list of schemas, you + do not need to implement this method. + """ + + def deregister_schema(self, name: str, cascade: bool) -> None:  # noqa: B027 + """Remove a schema from this catalog. + + This method is optional. If your catalog provides a fixed list of schemas, you + do not need to implement this method. + + Args: + name: The name of the schema to remove. + cascade: If ``True``, deregister the tables within the schema. + """ + + +class SchemaProvider(ABC): + """Abstract class for defining a Python-based Schema Provider.""" + + def owner_name(self) -> str | None: + """Returns the owner of the schema. + + This is an optional method. The default return is None. + """ + return None + + @abstractmethod + def table_names(self) -> set[str]: + """Set of the names of all tables in this schema.""" + ... + + @abstractmethod + def table(self, name: str) -> Table | None: + """Retrieve a specific table from this schema.""" + ... + + def register_table(self, name: str, table: Table) -> None:  # noqa: B027 + """Add a table to this schema. + + This method is optional. If your schema provides a fixed list of tables, you do + not need to implement this method. + """ + + def deregister_table(self, name: str, cascade: bool) -> None:  # noqa: B027 + """Remove a table from this schema. + + This method is optional. If your schema provides a fixed list of tables, you do + not need to implement this method. + """ + + @abstractmethod + def table_exist(self, name: str) -> bool: + """Returns true if the table exists in this schema.""" + ... + + +class SchemaProviderExportable(Protocol): + """Type hint for object that has __datafusion_schema_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.SchemaProvider.html + """ + + def __datafusion_schema_provider__(self) -> object: ... diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5b99b0d26..bce51d644 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -29,7 +29,7 @@ except ImportError: from typing_extensions import deprecated # Python 3.12 -from datafusion.catalog import Catalog, Table +from datafusion.catalog import Catalog, CatalogProvider, Table from datafusion.dataframe import DataFrame from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream @@ -80,6 +80,15 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ...
# noqa: D105 +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + + def __datafusion_catalog_provider__(self) -> object: ...  # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -749,6 +758,19 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def catalog_names(self) -> set[str]: + """Returns the set of catalog names in this context.""" + return self.ctx.catalog_names() + + def register_catalog_provider( + self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog + ) -> None: + """Register a catalog provider.""" + if isinstance(provider, Catalog): + self.ctx.register_catalog_provider(name, provider.catalog) + else: + self.ctx.register_catalog_provider(name, provider) + def register_table_provider( self, name: str, provider: TableProviderExportable ) -> None: diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 991e6875a..61cb09438 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -760,19 +760,16 @@ def join_on( exprs = [expr.expr for expr in on_exprs] return DataFrame(self.df.join_on(right.df, exprs, how)) - def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: - """Return a DataFrame with the explanation of its plan so far. + def explain(self, verbose: bool = False, analyze: bool = False) -> None: + """Print an explanation of the DataFrame's plan so far. If ``analyze`` is specified, runs the plan and reports metrics. Args: verbose: If ``True``, more details will be included. - analyze: If ``Tru`e``, the plan will run and metrics reported. + analyze: If ``True``, the plan will run and metrics will be reported. - - Returns: - DataFrame with the explanation of its plan. """ - return DataFrame(self.df.explain(verbose, analyze)) + self.df.explain(verbose, analyze) def logical_plan(self) -> LogicalPlan: """Return the unoptimized ``LogicalPlan``. diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 23b328458..1f9ecbfc3 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -14,9 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License.
+from __future__ import annotations +import datafusion as dfn import pyarrow as pa +import pyarrow.dataset as ds import pytest +from datafusion import SessionContext, Table # Note we take in `database` as a variable even though we don't use @@ -27,9 +31,9 @@ def test_basic(ctx, database): ctx.catalog("non-existent") default = ctx.catalog() - assert default.names() == ["public"] + assert default.names() == {"public"} - for db in [default.database("public"), default.database()]: + for db in [default.schema("public"), default.schema()]: assert db.names() == {"csv1", "csv", "csv2"} table = db.table("csv") @@ -41,3 +45,168 @@ def test_basic(ctx, database): pa.field("float", pa.float64(), nullable=True), ] ) + + +def create_dataset() -> Table: + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + dataset = ds.dataset([batch]) + return Table.from_dataset(dataset) + + +class CustomSchemaProvider(dfn.catalog.SchemaProvider): + def __init__(self): + self.tables = {"table1": create_dataset()} + + def table_names(self) -> set[str]: + return set(self.tables.keys()) + + def register_table(self, name: str, table: Table): + self.tables[name] = table + + def deregister_table(self, name, cascade: bool = True): + del self.tables[name] + + def table(self, name: str) -> Table | None: + return self.tables[name] + + def table_exist(self, name: str) -> bool: + return name in self.tables + + +class CustomCatalogProvider(dfn.catalog.CatalogProvider): + def __init__(self): + self.schemas = {"my_schema": CustomSchemaProvider()} + + def schema_names(self) -> set[str]: + return set(self.schemas.keys()) + + def schema(self, name: str): + return self.schemas[name] + + def register_schema(self, name: str, schema: dfn.catalog.Schema): + self.schemas[name] = schema + + def deregister_schema(self, name, cascade: bool): + del self.schemas[name] + + +def test_python_catalog_provider(ctx: SessionContext): + ctx.register_catalog_provider("my_catalog", CustomCatalogProvider()) + + # Check the default catalog provider + assert ctx.catalog("datafusion").names() == {"public"} + + my_catalog = ctx.catalog("my_catalog") + assert my_catalog.names() == {"my_schema"} + + my_catalog.register_schema("second_schema", CustomSchemaProvider()) + assert my_catalog.schema_names() == {"my_schema", "second_schema"} + + my_catalog.deregister_schema("my_schema") + assert my_catalog.schema_names() == {"second_schema"} + + +def test_in_memory_providers(ctx: SessionContext): + catalog = dfn.catalog.Catalog.memory_catalog() + ctx.register_catalog_provider("in_mem_catalog", catalog) + + assert ctx.catalog_names() == {"datafusion", "in_mem_catalog"} + + schema = dfn.catalog.Schema.memory_schema() + catalog.register_schema("in_mem_schema", schema) + + schema.register_table("my_table", create_dataset()) + + batches = ctx.sql("select * from in_mem_catalog.in_mem_schema.my_table").collect() + + assert len(batches) == 1 + assert batches[0].column(0) == pa.array([1, 2, 3]) + assert batches[0].column(1) == pa.array([4, 5, 6]) + + +def test_python_schema_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.deregister_schema("public") + + catalog.register_schema("test_schema1", CustomSchemaProvider()) + assert catalog.names() == {"test_schema1"} + + catalog.register_schema("test_schema2", CustomSchemaProvider()) + catalog.deregister_schema("test_schema1") + assert catalog.names() == {"test_schema2"} + + +def test_python_table_provider(ctx: SessionContext): + catalog = ctx.catalog() + + 
catalog.register_schema("custom_schema", CustomSchemaProvider()) + schema = catalog.schema("custom_schema") + + assert schema.table_names() == {"table1"} + + schema.deregister_table("table1") + schema.register_table("table2", create_dataset()) + assert schema.table_names() == {"table2"} + + # Use the default schema instead of our custom schema + + schema = catalog.schema() + + schema.register_table("table3", create_dataset()) + assert schema.table_names() == {"table3"} + + schema.deregister_table("table3") + schema.register_table("table4", create_dataset()) + assert schema.table_names() == {"table4"} + + +def test_in_end_to_end_python_providers(ctx: SessionContext): + """Test registering all python providers and running a query against them.""" + + all_catalog_names = [ + "datafusion", + "custom_catalog", + "in_mem_catalog", + ] + + all_schema_names = [ + "custom_schema", + "in_mem_schema", + ] + + ctx.register_catalog_provider(all_catalog_names[1], CustomCatalogProvider()) + ctx.register_catalog_provider( + all_catalog_names[2], dfn.catalog.Catalog.memory_catalog() + ) + + for catalog_name in all_catalog_names: + catalog = ctx.catalog(catalog_name) + + # Clean out previous schemas if they exist so we can start clean + for schema_name in catalog.schema_names(): + catalog.deregister_schema(schema_name, cascade=False) + + catalog.register_schema(all_schema_names[0], CustomSchemaProvider()) + catalog.register_schema(all_schema_names[1], dfn.catalog.Schema.memory_schema()) + + for schema_name in all_schema_names: + schema = catalog.schema(schema_name) + + for table_name in schema.table_names(): + schema.deregister_table(table_name) + + schema.register_table("test_table", create_dataset()) + + for catalog_name in all_catalog_names: + for schema_name in all_schema_names: + table_full_name = f"{catalog_name}.{schema_name}.test_table" + + batches = ctx.sql(f"select * from {table_full_name}").collect() + + assert len(batches) == 1 + assert batches[0].column(0) == pa.array([1, 2, 3]) + assert batches[0].column(1) == pa.array([4, 5, 6]) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 4a15ac9cf..6dbcc0d5e 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -57,7 +57,7 @@ def test_runtime_configs(tmp_path, path_to_str): ctx = SessionContext(config, runtime) assert ctx is not None - db = ctx.catalog("foo").database("bar") + db = ctx.catalog("foo").schema("bar") assert db is not None @@ -70,7 +70,7 @@ def test_temporary_files(tmp_path, path_to_str): ctx = SessionContext(config, runtime) assert ctx is not None - db = ctx.catalog("foo").database("bar") + db = ctx.catalog("foo").schema("bar") assert db is not None @@ -91,7 +91,7 @@ def test_create_context_with_all_valid_args(): ctx = SessionContext(config, runtime) # verify that at least some of the arguments worked - ctx.catalog("foo").database("bar") + ctx.catalog("foo").schema("bar") with pytest.raises(KeyError): ctx.catalog("datafusion") @@ -105,7 +105,7 @@ def test_register_record_batches(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -121,7 +121,7 @@ def test_create_dataframe_registers_unique_table_name(ctx): ) df = ctx.create_dataframe([[batch]]) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -141,7 +141,7 @@ def 
test_create_dataframe_registers_with_defined_table_name(ctx): ) df = ctx.create_dataframe([[batch]], name="tbl") - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -155,7 +155,7 @@ def test_from_arrow_table(ctx): # convert to DataFrame df = ctx.from_arrow(table) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -200,7 +200,7 @@ def test_from_arrow_table_with_name(ctx): # convert to DataFrame with optional name df = ctx.from_arrow(table, name="tbl") - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert tables[0] == "tbl" @@ -213,7 +213,7 @@ def test_from_arrow_table_empty(ctx): # convert to DataFrame df = ctx.from_arrow(table) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -228,7 +228,7 @@ def test_from_arrow_table_empty_no_schema(ctx): # convert to DataFrame df = ctx.from_arrow(table) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -246,7 +246,7 @@ def test_from_pylist(ctx): ] df = ctx.from_pylist(data) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -260,7 +260,7 @@ def test_from_pydict(ctx): data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = ctx.from_pydict(data) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -276,7 +276,7 @@ def test_from_pandas(ctx): pandas_df = pd.DataFrame(data) df = ctx.from_pandas(pandas_df) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -292,7 +292,7 @@ def test_from_polars(ctx): polars_df = pd.DataFrame(data) df = ctx.from_polars(polars_df) - tables = list(ctx.catalog().database().names()) + tables = list(ctx.catalog().schema().names()) assert df assert len(tables) == 1 @@ -303,7 +303,7 @@ def test_from_polars(ctx): def test_register_table(ctx, database): default = ctx.catalog() - public = default.database("public") + public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} table = public.table("csv") @@ -313,7 +313,7 @@ def test_register_table(ctx, database): def test_read_table(ctx, database): default = ctx.catalog() - public = default.database("public") + public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} table = public.table("csv") @@ -323,7 +323,7 @@ def test_read_table(ctx, database): def test_deregister_table(ctx, database): default = ctx.catalog() - public = default.database("public") + public = default.schema("public") assert public.names() == {"csv", "csv1", "csv2"} ctx.deregister_table("csv") @@ -339,7 +339,7 @@ def test_register_dataset(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -356,7 +356,7 @@ def test_dataset_filter(ctx, capfd): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5") 
# Make sure the filter was pushed down in Physical Plan @@ -455,7 +455,7 @@ def test_dataset_filter_nested_data(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} df = ctx.table("t") diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 41cee4ef3..c383edc60 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -75,7 +75,7 @@ def test_register_csv(ctx, tmp_path): ) ctx.register_csv("csv3", path, schema=alternative_schema) - assert ctx.catalog().database().names() == { + assert ctx.catalog().schema().names() == { "csv", "csv1", "csv2", @@ -150,7 +150,7 @@ def test_register_parquet(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) ctx.register_parquet("t", path) ctx.register_parquet("t1", str(path)) - assert ctx.catalog().database().names() == {"t", "t1"} + assert ctx.catalog().schema().names() == {"t", "t1"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) @@ -188,7 +188,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_ty parquet_pruning=True, file_extension=".parquet", ) - assert ctx.catalog().database().names() == {"datapp"} + assert ctx.catalog().schema().names() == {"datapp"} result = ctx.sql("SELECT grp, COUNT(*) AS cnt FROM datapp GROUP BY grp").collect() result = pa.Table.from_batches(result) @@ -204,7 +204,7 @@ def test_register_dataset(ctx, tmp_path, path_to_str): dataset = ds.dataset(path, format="parquet") ctx.register_dataset("t", dataset) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) @@ -251,7 +251,7 @@ def test_register_json(ctx, tmp_path): ) ctx.register_json("json3", path, schema=alternative_schema) - assert ctx.catalog().database().names() == { + assert ctx.catalog().schema().names() == { "json", "json1", "json2", @@ -308,7 +308,7 @@ def test_execute(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data)) ctx.register_parquet("t", path) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} # count result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL").collect() @@ -451,18 +451,10 @@ def test_udf( id="datetime_ns", ), # Not writable to parquet - pytest.param( - helpers.data_timedelta("s"), id="timedelta_s", marks=pytest.mark.xfail - ), - pytest.param( - helpers.data_timedelta("ms"), id="timedelta_ms", marks=pytest.mark.xfail - ), - pytest.param( - helpers.data_timedelta("us"), id="timedelta_us", marks=pytest.mark.xfail - ), - pytest.param( - helpers.data_timedelta("ns"), id="timedelta_ns", marks=pytest.mark.xfail - ), + pytest.param(helpers.data_timedelta("s"), id="timedelta_s"), + pytest.param(helpers.data_timedelta("ms"), id="timedelta_ms"), + pytest.param(helpers.data_timedelta("us"), id="timedelta_us"), + pytest.param(helpers.data_timedelta("ns"), id="timedelta_ns"), ], ) def test_simple_select(ctx, tmp_path, arr): @@ -524,7 +516,7 @@ def test_register_listing_table( schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, ) - assert ctx.catalog().database().names() == {"my_table"} + assert ctx.catalog().schema().names() == {"my_table"} result = ctx.sql( "SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp" ) diff --git
a/python/tests/test_substrait.py b/python/tests/test_substrait.py index f367a447d..43aa327d4 100644 --- a/python/tests/test_substrait.py +++ b/python/tests/test_substrait.py @@ -34,7 +34,7 @@ def test_substrait_serialization(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} # For now just make sure the method calls blow up substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM t", ctx) @@ -59,7 +59,7 @@ def test_substrait_file_serialization(ctx, tmp_path, path_to_str): ctx.register_record_batches("t", [[batch]]) - assert ctx.catalog().database().names() == {"t"} + assert ctx.catalog().schema().names() == {"t"} path = tmp_path / "substrait_plan" path = str(path) if path_to_str else path diff --git a/src/catalog.rs b/src/catalog.rs index 83f8d08cb..17d4ec3b8 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -15,44 +15,54 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::sync::Arc; - -use pyo3::exceptions::PyKeyError; -use pyo3::prelude::*; - -use crate::errors::{PyDataFusionError, PyDataFusionResult}; -use crate::utils::wait_for_future; +use crate::dataset::Dataset; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::utils::{validate_pycapsule, wait_for_future}; +use async_trait::async_trait; +use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider}; +use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, catalog::{CatalogProvider, SchemaProvider}, datasource::{TableProvider, TableType}, }; +use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; +use pyo3::IntoPyObjectExt; +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; -#[pyclass(name = "Catalog", module = "datafusion", subclass)] +#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] +#[derive(Clone)] pub struct PyCatalog { pub catalog: Arc<dyn CatalogProvider>, } -#[pyclass(name = "Database", module = "datafusion", subclass)] -pub struct PyDatabase { - pub database: Arc<dyn SchemaProvider>, +#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +#[derive(Clone)] +pub struct PySchema { + pub schema: Arc<dyn SchemaProvider>, } -#[pyclass(name = "Table", module = "datafusion", subclass)] +#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] +#[derive(Clone)] pub struct PyTable { pub table: Arc<dyn TableProvider>, } -impl PyCatalog { - pub fn new(catalog: Arc<dyn CatalogProvider>) -> Self { +impl From<Arc<dyn CatalogProvider>> for PyCatalog { + fn from(catalog: Arc<dyn CatalogProvider>) -> Self { Self { catalog } } } -impl PyDatabase { - pub fn new(database: Arc<dyn SchemaProvider>) -> Self { - Self { database } +impl From<Arc<dyn SchemaProvider>> for PySchema { + fn from(schema: Arc<dyn SchemaProvider>) -> Self { + Self { schema } } } @@ -68,36 +78,109 @@ impl PyTable { #[pymethods] impl PyCatalog { - fn names(&self) -> Vec<String> { - self.catalog.schema_names() + #[new] + fn new(catalog: PyObject) -> Self { + let catalog_provider = + Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc<dyn CatalogProvider>; + catalog_provider.into() + } + + #[staticmethod] + fn memory_catalog() -> Self { + let catalog_provider = + Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>; + catalog_provider.into() + } + + fn schema_names(&self) -> HashSet<String> { + self.catalog.schema_names().into_iter().collect() } #[pyo3(signature =
(name="public"))] - fn database(&self, name: &str) -> PyResult { - match self.catalog.schema(name) { - Some(database) => Ok(PyDatabase::new(database)), - None => Err(PyKeyError::new_err(format!( - "Database with name {name} doesn't exist." - ))), - } + fn schema(&self, name: &str) -> PyResult { + let schema = self + .catalog + .schema(name) + .ok_or(PyKeyError::new_err(format!( + "Schema with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)), + None => PySchema::from(schema).into_py_any(py), + } + }) + } + + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { + let capsule = schema_provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + Arc::new(provider) as Arc + } else { + match schema_provider.extract::() { + Ok(py_schema) => py_schema.schema, + Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into())) + as Arc, + } + }; + + let _ = self + .catalog + .register_schema(name, provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> { + let _ = self + .catalog + .deregister_schema(name, cascade) + .map_err(py_datafusion_err)?; + + Ok(()) } fn __repr__(&self) -> PyResult { - Ok(format!( - "Catalog(schema_names=[{}])", - self.names().join(";") - )) + let mut names: Vec = self.schema_names().into_iter().collect(); + names.sort(); + Ok(format!("Catalog(schema_names=[{}])", names.join(", "))) } } #[pymethods] -impl PyDatabase { - fn names(&self) -> HashSet { - self.database.table_names().into_iter().collect() +impl PySchema { + #[new] + fn new(schema_provider: PyObject) -> Self { + let schema_provider = + Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc; + schema_provider.into() + } + + #[staticmethod] + fn memory_schema() -> Self { + let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc; + schema_provider.into() + } + + #[getter] + fn table_names(&self) -> HashSet { + self.schema.table_names().into_iter().collect() } fn table(&self, name: &str, py: Python) -> PyDataFusionResult { - if let Some(table) = wait_for_future(py, self.database.table(name))?? { + if let Some(table) = wait_for_future(py, self.schema.table(name))?? { Ok(PyTable::new(table)) } else { Err(PyDataFusionError::Common(format!( @@ -107,14 +190,49 @@ impl PyDatabase { } fn __repr__(&self) -> PyResult { - Ok(format!( - "Database(table_names=[{}])", - Vec::from_iter(self.names()).join(";") - )) + let mut names: Vec = self.table_names().into_iter().collect(); + names.sort(); + Ok(format!("Schema(table_names=[{}])", names.join(";"))) } - // register_table - // deregister_table + fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if table_provider.hasattr("__datafusion_table_provider__")? { + let capsule = table_provider + .getattr("__datafusion_table_provider__")? 
+ .call0()?; + let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::<FFI_TableProvider>() }; + let provider: ForeignTableProvider = provider.into(); + Arc::new(provider) as Arc<dyn TableProvider> + } else { + match table_provider.extract::<PyTable>() { + Ok(py_table) => py_table.table, + Err(_) => { + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc<dyn TableProvider> + } + } + }; + + let _ = self + .schema + .register_table(name.to_string(), provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_table(&self, name: &str) -> PyResult<()> { + let _ = self + .schema + .deregister_table(name) + .map_err(py_datafusion_err)?; + + Ok(()) + } } #[pymethods] @@ -125,6 +243,14 @@ impl PyTable { self.table.schema().to_pyarrow(py) } + #[staticmethod] + fn from_dataset(py: Python<'_>, dataset: &Bound<'_, PyAny>) -> PyResult<Self> { + let ds = Arc::new(Dataset::new(dataset, py).map_err(py_datafusion_err)?) + as Arc<dyn TableProvider>; + + Ok(Self::new(ds)) + } + /// Get the type of this table for metadata/catalog purposes. #[getter] fn kind(&self) -> &str { @@ -145,3 +271,285 @@ impl PyTable { // fn has_exact_statistics // fn supports_filter_pushdown } + +#[derive(Debug)] +pub(crate) struct RustWrappedPySchemaProvider { + schema_provider: PyObject, + owner_name: Option<String>, +} + +impl RustWrappedPySchemaProvider { + pub fn new(schema_provider: PyObject) -> Self { + let owner_name = Python::with_gil(|py| { + schema_provider + .bind(py) + .getattr("owner_name") + .ok() + .map(|name| name.to_string()) + }); + + Self { + schema_provider, + owner_name, + } + } + + fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let py_table_method = provider.getattr("table")?; + + let py_table = py_table_method.call((name,), None)?; + if py_table.is_none() { + return Ok(None); + } + + if py_table.hasattr("__datafusion_table_provider__")?
{ + let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::<FFI_TableProvider>() }; + let provider: ForeignTableProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>)) + } else { + if let Ok(inner_table) = py_table.getattr("table") { + if let Ok(inner_table) = inner_table.extract::<PyTable>() { + return Ok(Some(inner_table.table)); + } + } + + match py_table.extract::<PyTable>() { + Ok(py_table) => Ok(Some(py_table.table)), + Err(_) => { + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>)) + } + } + } + }) + } +} + +#[async_trait] +impl SchemaProvider for RustWrappedPySchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec<String> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + + provider + .getattr("table_names") + .and_then(|names| names.extract::<Vec<String>>()) + .unwrap_or_else(|err| { + log::error!("Unable to get table_names: {err}"); + Vec::default() + }) + }) + } + + async fn table( + &self, + name: &str, + ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> { + self.table_inner(name).map_err(to_datafusion_err) + } + + fn register_table( + &self, + name: String, + table: Arc<dyn TableProvider>, + ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> { + let py_table = PyTable::new(table); + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let _ = provider + .call_method1("register_table", (name, py_table)) + .map_err(to_datafusion_err)?; + // Since the definition of `register_table` says that an error + // will be returned if the table already exists, there is no + // case where we want to return a table provider as output. + Ok(None) + }) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let table = provider + .call_method1("deregister_table", (name,)) + .map_err(to_datafusion_err)?; + if table.is_none() { + return Ok(None); + } + + // If we can turn this table provider into a `Dataset`, return it. + // Otherwise, return None. + let dataset = match Dataset::new(&table, py) { + Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>), + Err(_) => None, + }; + + Ok(dataset) + }) + } + + fn table_exist(&self, name: &str) -> bool { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + provider + .call_method1("table_exist", (name,)) + .and_then(|pyobj| pyobj.extract()) + .unwrap_or(false) + }) + } +} + +#[derive(Debug)] +pub(crate) struct RustWrappedPyCatalogProvider { + pub(crate) catalog_provider: PyObject, +} + +impl RustWrappedPyCatalogProvider { + pub fn new(catalog_provider: PyObject) -> Self { + Self { catalog_provider } + } + + fn schema_inner(&self, name: &str) -> PyResult<Option<Arc<dyn SchemaProvider>>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + + let py_schema = provider.call_method1("schema", (name,))?; + if py_schema.is_none() { + return Ok(None); + } + + if py_schema.hasattr("__datafusion_schema_provider__")? { + let capsule = py_schema + .getattr("__datafusion_schema_provider__")?
+ .call0()?; + let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() }; + let provider: ForeignSchemaProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc<dyn SchemaProvider>)) + } else { + if let Ok(inner_schema) = py_schema.getattr("schema") { + if let Ok(inner_schema) = inner_schema.extract::<PySchema>() { + return Ok(Some(inner_schema.schema)); + } + } + match py_schema.extract::<PySchema>() { + Ok(inner_schema) => Ok(Some(inner_schema.schema)), + Err(_) => { + let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); + + Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>)) + } + } + } + }) + } +} + +#[async_trait] +impl CatalogProvider for RustWrappedPyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec<String> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + provider + .getattr("schema_names") + .and_then(|names| names.extract::<Vec<String>>()) + .unwrap_or_else(|err| { + log::error!("Unable to get schema_names: {err}"); + Vec::default() + }) + }) + } + + fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> { + self.schema_inner(name).unwrap_or_else(|err| { + log::error!("CatalogProvider schema returned error: {err}"); + None + }) + } + + fn register_schema( + &self, + name: &str, + schema: Arc<dyn SchemaProvider>, + ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> { + Python::with_gil(|py| { + // If the schema is already a wrapped Python provider, pass the original + // Python object through; otherwise wrap it in a PySchema first. + let py_schema = match schema + .as_any() + .downcast_ref::<RustWrappedPySchemaProvider>() + { + Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(), + None => &PySchema::from(schema) + .into_py_any(py) + .map_err(to_datafusion_err)?, + }; + + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("register_schema", (name, py_schema)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc<dyn SchemaProvider>; + + Ok(Some(schema)) + }) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("deregister_schema", (name, cascade)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc<dyn SchemaProvider>; + + Ok(Some(schema)) + }) + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::<PyCatalog>()?; + m.add_class::<PySchema>()?; + m.add_class::<PyTable>()?; + + Ok(()) +} diff --git a/src/common/data_type.rs b/src/common/data_type.rs index f5f8a6b06..5cf9d6e9f 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -172,7 +172,7 @@ impl DataTypeMap { SqlType::DATE, )), DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Interval(interval_unit) => Ok(DataTypeMap::new( DataType::Interval(*interval_unit), @@ -189,7 +189,7 @@ impl DataTypeMap { SqlType::BINARY, )), DataType::FixedSizeBinary(_) => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", arrow_type)), + DataFusionError::NotImplemented(format!("{arrow_type:?}")), )), DataType::LargeBinary => Ok(DataTypeMap::new( DataType::LargeBinary, @@ -207,23 +207,22 @@ impl DataTypeMap { SqlType::VARCHAR, )), DataType::List(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", -
arrow_type + "{arrow_type:?}" )))), DataType::FixedSizeList(_, _) => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", arrow_type)), + DataFusionError::NotImplemented(format!("{arrow_type:?}")), )), DataType::LargeList(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Struct(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Union(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new( DataType::Decimal128(*precision, *scale), @@ -236,23 +235,22 @@ impl DataTypeMap { SqlType::DECIMAL, )), DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::RunEndEncoded(_, _) => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", arrow_type)), + DataFusionError::NotImplemented(format!("{arrow_type:?}")), )), DataType::BinaryView => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::Utf8View => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - arrow_type + "{arrow_type:?}" )))), DataType::ListView(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), DataType::LargeListView(_) => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", arrow_type), + format!("{arrow_type:?}"), ))), } } @@ -379,8 +377,7 @@ impl DataTypeMap { "double" => Ok(DataType::Float64), "byte_array" => Ok(DataType::Utf8), _ => Err(PyValueError::new_err(format!( - "Unable to determine Arrow Data Type from Parquet String type: {:?}", - parquet_str_type + "Unable to determine Arrow Data Type from Parquet String type: {parquet_str_type:?}" ))), }; DataTypeMap::map_from_arrow_type(&arrow_dtype?) 
@@ -404,12 +401,10 @@ impl DataTypeMap { pub fn py_map_from_sql_type(sql_type: &SqlType) -> PyResult { match sql_type { SqlType::ANY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::ARRAY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::BIGINT => Ok(DataTypeMap::new( DataType::Int64, @@ -432,11 +427,10 @@ impl DataTypeMap { SqlType::CHAR, )), SqlType::COLUMN_LIST => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::CURSOR => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::DATE => Ok(DataTypeMap::new( DataType::Date64, @@ -449,8 +443,7 @@ impl DataTypeMap { SqlType::DECIMAL, )), SqlType::DISTINCT => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::DOUBLE => Ok(DataTypeMap::new( DataType::Decimal256(1, 1), @@ -458,7 +451,7 @@ impl DataTypeMap { SqlType::DOUBLE, )), SqlType::DYNAMIC_STAR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::FLOAT => Ok(DataTypeMap::new( DataType::Decimal128(1, 1), @@ -466,8 +459,7 @@ impl DataTypeMap { SqlType::FLOAT, )), SqlType::GEOMETRY => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::INTEGER => Ok(DataTypeMap::new( DataType::Int8, @@ -475,55 +467,52 @@ impl DataTypeMap { SqlType::INTEGER, )), SqlType::INTERVAL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - sql_type + "{sql_type:?}" )))), SqlType::INTERVAL_DAY => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_DAY_HOUR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_DAY_MINUTE => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_DAY_SECOND => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_HOUR => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_HOUR_MINUTE => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_HOUR_SECOND => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_MINUTE => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_MINUTE_SECOND => Err(py_datafusion_err( - DataFusionError::NotImplemented(format!("{:?}", sql_type)), + DataFusionError::NotImplemented(format!("{sql_type:?}")), )), SqlType::INTERVAL_MONTH => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), SqlType::INTERVAL_SECOND => Err(py_datafusion_err(DataFusionError::NotImplemented( - format!("{:?}", sql_type), + format!("{sql_type:?}"), ))), 
            SqlType::INTERVAL_YEAR => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", sql_type),
+                format!("{sql_type:?}"),
            ))),
            SqlType::INTERVAL_YEAR_MONTH => Err(py_datafusion_err(
-                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
+                DataFusionError::NotImplemented(format!("{sql_type:?}")),
            )),
            SqlType::MAP => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::MULTISET => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::NULL => Ok(DataTypeMap::new(
                DataType::Null,
@@ -531,20 +520,16 @@ impl DataTypeMap {
                SqlType::NULL,
            )),
            SqlType::OTHER => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::REAL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::ROW => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::SARG => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::SMALLINT => Ok(DataTypeMap::new(
                DataType::Int16,
@@ -552,25 +537,22 @@ impl DataTypeMap {
                SqlType::SMALLINT,
            )),
            SqlType::STRUCTURED => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", sql_type),
+                format!("{sql_type:?}"),
            ))),
            SqlType::SYMBOL => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::TIME => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::TIME_WITH_LOCAL_TIME_ZONE => Err(py_datafusion_err(
-                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
+                DataFusionError::NotImplemented(format!("{sql_type:?}")),
            )),
            SqlType::TIMESTAMP => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::TIMESTAMP_WITH_LOCAL_TIME_ZONE => Err(py_datafusion_err(
-                DataFusionError::NotImplemented(format!("{:?}", sql_type)),
+                DataFusionError::NotImplemented(format!("{sql_type:?}")),
            )),
            SqlType::TINYINT => Ok(DataTypeMap::new(
                DataType::Int8,
@@ -578,8 +560,7 @@ impl DataTypeMap {
                SqlType::TINYINT,
            )),
            SqlType::UNKNOWN => Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                "{:?}",
-                sql_type
+                "{sql_type:?}"
            )))),
            SqlType::VARBINARY => Ok(DataTypeMap::new(
                DataType::LargeBinary,
@@ -682,8 +663,7 @@ impl PyDataType {
            "datetime64" => Ok(DataType::Date64),
            "object" => Ok(DataType::Utf8),
            _ => Err(PyValueError::new_err(format!(
-                "Unable to determine Arrow Data Type from Arrow String type: {:?}",
-                arrow_str_type
+                "Unable to determine Arrow Data Type from Arrow String type: {arrow_str_type:?}"
            ))),
        };
        Ok(PyDataType {
diff --git a/src/context.rs b/src/context.rs
index 6ce1f12bc..36133a33d 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -31,7 +31,7 @@ use uuid::Uuid;
 use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
-use crate::catalog::{PyCatalog, PyTable};
+use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
@@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::catalog::CatalogProvider;
 use datafusion::common::TableReference;
 use datafusion::common::{exec_err, ScalarValue};
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -69,8 +70,10 @@ use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
+use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
 use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
 use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
+use pyo3::IntoPyObjectExt;
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
@@ -365,7 +368,7 @@ impl PySessionContext {
         } else {
             &upstream_host
         };
-        let url_string = format!("{}{}", scheme, derived_host);
+        let url_string = format!("{scheme}{derived_host}");
         let url = Url::parse(&url_string).unwrap();
         self.ctx.runtime_env().register_object_store(&url, store);
         Ok(())
@@ -614,6 +617,34 @@ impl PySessionContext {
         Ok(())
     }

+    pub fn register_catalog_provider(
+        &mut self,
+        name: &str,
+        provider: Bound<'_, PyAny>,
+    ) -> PyDataFusionResult<()> {
+        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
+            let capsule = provider
+                .getattr("__datafusion_catalog_provider__")?
+                .call0()?;
+            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            validate_pycapsule(capsule, "datafusion_catalog_provider")?;
+
+            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
+            let provider: ForeignCatalogProvider = provider.into();
+            Arc::new(provider) as Arc<dyn CatalogProvider>
+        } else {
+            match provider.extract::<PyCatalog>() {
+                Ok(py_catalog) => py_catalog.catalog,
+                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
+                    as Arc<dyn CatalogProvider>,
+            }
+        };
+
+        let _ = self.ctx.register_catalog(name, provider);
+
+        Ok(())
+    }
+
     /// Construct datafusion dataframe from Arrow Table
     pub fn register_table_provider(
         &mut self,
@@ -845,14 +876,24 @@ impl PySessionContext {
     }

     #[pyo3(signature = (name="datafusion"))]
-    pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
-        match self.ctx.catalog(name) {
-            Some(catalog) => Ok(PyCatalog::new(catalog)),
-            None => Err(PyKeyError::new_err(format!(
-                "Catalog with name {} doesn't exist.",
-                &name,
-            ))),
-        }
+    pub fn catalog(&self, name: &str) -> PyResult<PyObject> {
+        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
+            "Catalog with name {name} doesn't exist."
+        )))?;
+
+        Python::with_gil(|py| {
+            match catalog
+                .as_any()
+                .downcast_ref::<RustWrappedPyCatalogProvider>()
+            {
+                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
+                None => PyCatalog::from(catalog).into_py_any(py),
+            }
+        })
+    }
+
+    pub fn catalog_names(&self) -> HashSet<String> {
+        self.ctx.catalog_names().into_iter().collect()
     }

     pub fn tables(&self) -> HashSet<String> {
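Taken together, the `register_catalog_provider`, `catalog`, and `catalog_names` bindings above accept three kinds of catalog: a capsule-exported Rust provider, the built-in `PyCatalog`, and an arbitrary Python object that gets wrapped in `RustWrappedPyCatalogProvider` (and unwrapped again on the way back out, per the `downcast_ref` branch). A minimal sketch of driving this from Python follows; the exact method surface expected of the provider class is an assumption for illustration, not a confirmed interface:

.. code-block:: python

    from datafusion import SessionContext


    class DictCatalogProvider:
        """Hypothetical provider keeping schemas in a plain dict."""

        def __init__(self) -> None:
            self._schemas = {}

        def schema_names(self) -> set:
            return set(self._schemas)

        def schema(self, name: str):
            return self._schemas.get(name)


    ctx = SessionContext()
    ctx.register_catalog_provider("my_catalog", DictCatalogProvider())

    # catalog_names() reflects the newly registered provider.
    assert "my_catalog" in ctx.catalog_names()

    # Fetching it back returns the original Python object rather than a
    # wrapper around the Rust wrapper (see the downcast_ref above).
    catalog = ctx.catalog("my_catalog")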
diff --git a/src/expr.rs b/src/expr.rs
index 6b1d01d65..f1e002367 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -171,12 +171,10 @@ impl PyExpr {
             Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_bound_py_any(py)?),
             Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_bound_py_any(py)?),
             Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
-                "Converting Expr::ScalarFunction to a Python object is not implemented: {:?}",
-                value
+                "Converting Expr::ScalarFunction to a Python object is not implemented: {value:?}"
             ))),
             Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
-                "Converting Expr::WindowFunction to a Python object is not implemented: {:?}",
-                value
+                "Converting Expr::WindowFunction to a Python object is not implemented: {value:?}"
             ))),
             Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_bound_py_any(py)?),
             Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_bound_py_any(py)?),
@@ -188,8 +186,7 @@
             }
             #[allow(deprecated)]
             Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!(
-                "Converting Expr::Wildcard to a Python object is not implemented : {:?} {:?}",
-                qualifier, options
+                "Converting Expr::Wildcard to a Python object is not implemented : {qualifier:?} {options:?}"
             ))),
             Expr::GroupingSet(value) => {
                 Ok(grouping_set::PyGroupingSet::from(value.clone()).into_bound_py_any(py)?)
             }
             Expr::Placeholder(value) => {
                 Ok(placeholder::PyPlaceholder::from(value.clone()).into_bound_py_any(py)?)
             }
             Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
-                "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}",
-                data_type, column
+                "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {data_type:?} - {column:?}"
             ))),
             Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_bound_py_any(py)?),
         }
@@ -755,8 +751,7 @@ impl PyExpr {
             Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
             Expr::Literal(scalar_value, _) => DataTypeMap::map_from_scalar_value(scalar_value),
             _ => Err(py_type_err(format!(
-                "Non Expr::Literal encountered in types: {:?}",
-                expr
+                "Non Expr::Literal encountered in types: {expr:?}"
             ))),
         }
     }
diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs
index a99d83d23..fd4393271 100644
--- a/src/expr/aggregate.rs
+++ b/src/expr/aggregate.rs
@@ -116,7 +116,7 @@ impl PyAggregate {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Aggregate({})", self))
+        Ok(format!("Aggregate({self})"))
     }
 }
diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs
index c09f116e3..7c5d3d31f 100644
--- a/src/expr/aggregate_expr.rs
+++ b/src/expr/aggregate_expr.rs
@@ -75,6 +75,6 @@ impl PyAggregateFunction {
     /// Get a String representation of this column
     fn __repr__(&self) -> String {
-        format!("{}", self)
+        format!("{self}")
     }
 }
diff --git a/src/expr/alias.rs b/src/expr/alias.rs
index e8e03cfad..40746f200 100644
--- a/src/expr/alias.rs
+++ b/src/expr/alias.rs
@@ -64,6 +64,6 @@ impl PyAlias {
     /// Get a String representation of this column
     fn __repr__(&self) -> String {
-        format!("{}", self)
+        format!("{self}")
     }
 }
diff --git a/src/expr/analyze.rs b/src/expr/analyze.rs
index 62f93cd26..e8081e95b 100644
--- a/src/expr/analyze.rs
+++ b/src/expr/analyze.rs
@@ -69,7 +69,7 @@ impl PyAnalyze {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Analyze({})", self))
+        Ok(format!("Analyze({self})"))
     }
 }
diff --git a/src/expr/between.rs b/src/expr/between.rs
index a2cac1442..817f1baae 100644
--- a/src/expr/between.rs
+++ b/src/expr/between.rs
@@ -71,6 +71,6 @@ impl PyBetween {
     }

     fn __repr__(&self) -> String {
-        format!("{}", self)
+        format!("{self}")
     }
 }
diff --git a/src/expr/column.rs b/src/expr/column.rs
index 365dbc0d2..50f316f1c 100644
--- a/src/expr/column.rs
+++ b/src/expr/column.rs
@@ -45,7 +45,7 @@ impl PyColumn {
     /// Get the column relation
     fn relation(&self) -> Option<String> {
-        self.col.relation.as_ref().map(|r| format!("{}", r))
+        self.col.relation.as_ref().map(|r| format!("{r}"))
     }

     /// Get the fully-qualified column name
diff --git a/src/expr/copy_to.rs b/src/expr/copy_to.rs
index ebfcb8ebc..473dabfed 100644
--- a/src/expr/copy_to.rs
+++ b/src/expr/copy_to.rs
@@ -106,7 +106,7 @@ impl PyCopyTo {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CopyTo({})", self))
+        Ok(format!("CopyTo({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
@@ -129,7 +129,7 @@ impl Display for PyFileType {

 #[pymethods]
 impl PyFileType {
     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("FileType({})", self))
+        Ok(format!("FileType({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/create_catalog.rs b/src/expr/create_catalog.rs
index f4ea0f517..d2d2ee8f6 100644
--- a/src/expr/create_catalog.rs
+++ b/src/expr/create_catalog.rs
@@ -81,7 +81,7 @@ impl PyCreateCatalog {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateCatalog({})", self))
+        Ok(format!("CreateCatalog({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
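The long run of `__repr__` hunks on either side of this point is the same mechanical change applied file by file: `format!("Filter({})", self)` becomes `format!("Filter({self})")` with identical output. These repr strings are what Python users see when inspecting logical plan variants; a small sketch (the table, predicate, and printed output are illustrative):

.. code-block:: python

    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.sql("CREATE TABLE t AS VALUES (1), (2)")
    plan = ctx.sql("SELECT * FROM t WHERE column1 > 1").logical_plan()

    # to_variant() returns a typed wrapper (e.g. Projection or Filter)
    # whose repr is produced by the __repr__ implementations in these hunks.
    print(repr(plan.to_variant()))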
diff --git a/src/expr/create_catalog_schema.rs b/src/expr/create_catalog_schema.rs
index 85f447e1e..e794962f5 100644
--- a/src/expr/create_catalog_schema.rs
+++ b/src/expr/create_catalog_schema.rs
@@ -81,7 +81,7 @@ impl PyCreateCatalogSchema {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateCatalogSchema({})", self))
+        Ok(format!("CreateCatalogSchema({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/create_external_table.rs b/src/expr/create_external_table.rs
index 01ce7d0ca..3e35af006 100644
--- a/src/expr/create_external_table.rs
+++ b/src/expr/create_external_table.rs
@@ -164,7 +164,7 @@ impl PyCreateExternalTable {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateExternalTable({})", self))
+        Ok(format!("CreateExternalTable({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/create_function.rs b/src/expr/create_function.rs
index 6f3c3f0ff..c02ceebb1 100644
--- a/src/expr/create_function.rs
+++ b/src/expr/create_function.rs
@@ -163,7 +163,7 @@ impl PyCreateFunction {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateFunction({})", self))
+        Ok(format!("CreateFunction({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/create_index.rs b/src/expr/create_index.rs
index 13dadbc3f..0f4b5011a 100644
--- a/src/expr/create_index.rs
+++ b/src/expr/create_index.rs
@@ -110,7 +110,7 @@ impl PyCreateIndex {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateIndex({})", self))
+        Ok(format!("CreateIndex({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/create_memory_table.rs b/src/expr/create_memory_table.rs
index 8872b2d47..37f4d3420 100644
--- a/src/expr/create_memory_table.rs
+++ b/src/expr/create_memory_table.rs
@@ -78,7 +78,7 @@ impl PyCreateMemoryTable {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateMemoryTable({})", self))
+        Ok(format!("CreateMemoryTable({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/create_view.rs b/src/expr/create_view.rs
index 87bb76876..718e404d0 100644
--- a/src/expr/create_view.rs
+++ b/src/expr/create_view.rs
@@ -75,7 +75,7 @@ impl PyCreateView {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("CreateView({})", self))
+        Ok(format!("CreateView({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/describe_table.rs b/src/expr/describe_table.rs
index 5658a13f2..6c48f3c77 100644
--- a/src/expr/describe_table.rs
+++ b/src/expr/describe_table.rs
@@ -61,7 +61,7 @@ impl PyDescribeTable {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("DescribeTable({})", self))
+        Ok(format!("DescribeTable({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/distinct.rs b/src/expr/distinct.rs
index b62b776f8..889e7099d 100644
--- a/src/expr/distinct.rs
+++ b/src/expr/distinct.rs
@@ -48,8 +48,7 @@ impl Display for PyDistinct {
             Distinct::All(input) => write!(
                 f,
                 "Distinct ALL
-                \nInput: {:?}",
-                input,
+                \nInput: {input:?}",
             ),
             Distinct::On(distinct_on) => {
                 write!(
@@ -71,7 +70,7 @@ impl PyDistinct {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Distinct({})", self))
+        Ok(format!("Distinct({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/drop_catalog_schema.rs b/src/expr/drop_catalog_schema.rs
index b7420a99c..b4a4c521c 100644
--- a/src/expr/drop_catalog_schema.rs
+++ b/src/expr/drop_catalog_schema.rs
@@ -101,7 +101,7 @@ impl PyDropCatalogSchema {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("DropCatalogSchema({})", self))
+        Ok(format!("DropCatalogSchema({self})"))
     }
 }
diff --git a/src/expr/drop_function.rs b/src/expr/drop_function.rs
index 9fbd78fdc..fca9eb94b 100644
--- a/src/expr/drop_function.rs
+++ b/src/expr/drop_function.rs
@@ -76,7 +76,7 @@ impl PyDropFunction {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("DropFunction({})", self))
+        Ok(format!("DropFunction({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/drop_table.rs b/src/expr/drop_table.rs
index 96983c1cf..3f442539a 100644
--- a/src/expr/drop_table.rs
+++ b/src/expr/drop_table.rs
@@ -70,7 +70,7 @@ impl PyDropTable {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("DropTable({})", self))
+        Ok(format!("DropTable({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/drop_view.rs b/src/expr/drop_view.rs
index 1d1ab1e59..6196c8bb5 100644
--- a/src/expr/drop_view.rs
+++ b/src/expr/drop_view.rs
@@ -83,7 +83,7 @@ impl PyDropView {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("DropView({})", self))
+        Ok(format!("DropView({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/empty_relation.rs b/src/expr/empty_relation.rs
index a1534ac15..758213423 100644
--- a/src/expr/empty_relation.rs
+++ b/src/expr/empty_relation.rs
@@ -65,7 +65,7 @@ impl PyEmptyRelation {
     /// Get a String representation of this column
     fn __repr__(&self) -> String {
-        format!("{}", self)
+        format!("{self}")
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/filter.rs b/src/expr/filter.rs
index 9bdb667cd..4fcb600cd 100644
--- a/src/expr/filter.rs
+++ b/src/expr/filter.rs
@@ -72,7 +72,7 @@ impl PyFilter {
     }

     fn __repr__(&self) -> String {
-        format!("Filter({})", self)
+        format!("Filter({self})")
     }
 }
diff --git a/src/expr/join.rs b/src/expr/join.rs
index 76ec532e7..b8d1d9da7 100644
--- a/src/expr/join.rs
+++ b/src/expr/join.rs
@@ -177,7 +177,7 @@ impl PyJoin {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Join({})", self))
+        Ok(format!("Join({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/like.rs b/src/expr/like.rs
index 2e1f060bd..f180f5d4c 100644
--- a/src/expr/like.rs
+++ b/src/expr/like.rs
@@ -75,7 +75,7 @@ impl PyLike {
     }

     fn __repr__(&self) -> String {
-        format!("Like({})", self)
+        format!("Like({self})")
     }
 }

@@ -133,7 +133,7 @@ impl PyILike {
     }

     fn __repr__(&self) -> String {
-        format!("Like({})", self)
+        format!("Like({self})")
     }
 }

@@ -191,6 +191,6 @@ impl PySimilarTo {
     }

     fn __repr__(&self) -> String {
-        format!("Like({})", self)
+        format!("Like({self})")
     }
 }
diff --git a/src/expr/limit.rs b/src/expr/limit.rs
index c2a33ff89..92552814e 100644
--- a/src/expr/limit.rs
+++ b/src/expr/limit.rs
@@ -81,7 +81,7 @@ impl PyLimit {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Limit({})", self))
+        Ok(format!("Limit({self})"))
     }
 }
diff --git a/src/expr/projection.rs b/src/expr/projection.rs
index dc7e5e3c1..b5a9ef34a 100644
--- a/src/expr/projection.rs
+++ b/src/expr/projection.rs
@@ -85,7 +85,7 @@ impl PyProjection {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Projection({})", self))
+        Ok(format!("Projection({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/recursive_query.rs b/src/expr/recursive_query.rs
index 65181f7d3..2517b7417 100644
--- a/src/expr/recursive_query.rs
+++ b/src/expr/recursive_query.rs
@@ -89,7 +89,7 @@ impl PyRecursiveQuery {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("RecursiveQuery({})", self))
+        Ok(format!("RecursiveQuery({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/repartition.rs b/src/expr/repartition.rs
index 3e782d6af..48b5e7041 100644
--- a/src/expr/repartition.rs
+++ b/src/expr/repartition.rs
@@ -108,7 +108,7 @@ impl PyRepartition {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Repartition({})", self))
+        Ok(format!("Repartition({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/sort.rs b/src/expr/sort.rs
index ed4947591..79a8aee50 100644
--- a/src/expr/sort.rs
+++ b/src/expr/sort.rs
@@ -87,7 +87,7 @@ impl PySort {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Sort({})", self))
+        Ok(format!("Sort({self})"))
     }
 }
diff --git a/src/expr/sort_expr.rs b/src/expr/sort_expr.rs
index 12f74e4d8..79e35d978 100644
--- a/src/expr/sort_expr.rs
+++ b/src/expr/sort_expr.rs
@@ -85,6 +85,6 @@ impl PySortExpr {
     }

     fn __repr__(&self) -> String {
-        format!("{}", self)
+        format!("{self}")
     }
 }
diff --git a/src/expr/subquery.rs b/src/expr/subquery.rs
index 5ebfe6927..77f56f9a9 100644
--- a/src/expr/subquery.rs
+++ b/src/expr/subquery.rs
@@ -62,7 +62,7 @@ impl PySubquery {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Subquery({})", self))
+        Ok(format!("Subquery({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/subquery_alias.rs b/src/expr/subquery_alias.rs
index 267a4d485..3302e7f23 100644
--- a/src/expr/subquery_alias.rs
+++ b/src/expr/subquery_alias.rs
@@ -72,7 +72,7 @@ impl PySubqueryAlias {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("SubqueryAlias({})", self))
+        Ok(format!("SubqueryAlias({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/table_scan.rs b/src/expr/table_scan.rs
index 6a0d53f0f..329964687 100644
--- a/src/expr/table_scan.rs
+++ b/src/expr/table_scan.rs
@@ -136,7 +136,7 @@ impl PyTableScan {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("TableScan({})", self))
+        Ok(format!("TableScan({self})"))
     }
 }
diff --git a/src/expr/union.rs b/src/expr/union.rs
index 5a08ccc13..e0b221398 100644
--- a/src/expr/union.rs
+++ b/src/expr/union.rs
@@ -66,7 +66,7 @@ impl PyUnion {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Union({})", self))
+        Ok(format!("Union({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/unnest.rs b/src/expr/unnest.rs
index 8e70e0990..c8833347f 100644
--- a/src/expr/unnest.rs
+++ b/src/expr/unnest.rs
@@ -66,7 +66,7 @@ impl PyUnnest {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("Unnest({})", self))
+        Ok(format!("Unnest({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/unnest_expr.rs b/src/expr/unnest_expr.rs
index 2234d24b1..634186ed8 100644
--- a/src/expr/unnest_expr.rs
+++ b/src/expr/unnest_expr.rs
@@ -58,7 +58,7 @@ impl PyUnnestExpr {
     }

     fn __repr__(&self) -> PyResult<String> {
-        Ok(format!("UnnestExpr({})", self))
+        Ok(format!("UnnestExpr({self})"))
     }

     fn __name__(&self) -> PyResult<String> {
diff --git a/src/expr/window.rs b/src/expr/window.rs
index 052d9eeb4..a408731c2 100644
--- a/src/expr/window.rs
+++ b/src/expr/window.rs
@@ -185,8 +185,7 @@ impl PyWindowFrame {
             "groups" => WindowFrameUnits::Groups,
             _ => {
                 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                    "{:?}",
-                    units,
+                    "{units:?}",
                 ))));
             }
         };
@@ -197,8 +196,7 @@ impl PyWindowFrame {
             WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
             WindowFrameUnits::Groups => {
                 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                    "{:?}",
-                    units,
+                    "{units:?}",
                 ))));
             }
         },
@@ -210,8 +208,7 @@ impl PyWindowFrame {
             WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
             WindowFrameUnits::Groups => {
                 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
-                    "{:?}",
-                    units,
+                    "{units:?}",
                 ))));
             }
         },
@@ -236,7 +233,7 @@ impl PyWindowFrame {
     /// Get a String representation of this window frame
     fn __repr__(&self) -> String {
-        format!("{}", self)
+        format!("{self}")
     }
 }
diff --git a/src/functions.rs b/src/functions.rs
index b2bafcb65..b40500b8b 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -937,7 +937,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(left))?;
     m.add_wrapped(wrap_pyfunction!(length))?;
     m.add_wrapped(wrap_pyfunction!(ln))?;
-    m.add_wrapped(wrap_pyfunction!(log))?;
+    m.add_wrapped(wrap_pyfunction!(self::log))?;
     m.add_wrapped(wrap_pyfunction!(log10))?;
     m.add_wrapped(wrap_pyfunction!(log2))?;
     m.add_wrapped(wrap_pyfunction!(lower))?;
diff --git a/src/lib.rs b/src/lib.rs
index 1293eee3c..29d3f41da 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -77,10 +77,10 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
 /// datafusion directory.
 #[pymodule]
 fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
+    // Initialize logging
+    pyo3_log::init();
+
     // Register the python classes
-    m.add_class::<catalog::PyCatalog>()?;
-    m.add_class::<catalog::PyDatabase>()?;
-    m.add_class::<catalog::PyTable>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
@@ -98,6 +98,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;

+    let catalog = PyModule::new(py, "catalog")?;
+    catalog::init_module(&catalog)?;
+    m.add_submodule(&catalog)?;
+
     // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/
     let common = PyModule::new(py, "common")?;
     common::init_module(&common)?;
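The `pyo3_log::init()` call above bridges the Rust `log` facade into Python's standard `logging` module, which is what the new `log` and `pyo3-log` dependencies are for (and presumably why `wrap_pyfunction!(log)` now needs the `self::log` disambiguation in functions.rs, since the `log` crate is in scope there). On the Python side nothing new is required beyond the usual logging configuration; a minimal sketch:

.. code-block:: python

    import logging

    # Rust-side records are forwarded to loggers named after their Rust
    # module path (e.g. "datafusion...."), so standard configuration applies.
    logging.basicConfig(level=logging.DEBUG)

    from datafusion import SessionContext

    ctx = SessionContext()
    # Any Rust log records emitted while running now flow through `logging`.
    ctx.sql("SELECT 1").collect()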
diff --git a/src/physical_plan.rs b/src/physical_plan.rs
index f0be45c6a..49db643e1 100644
--- a/src/physical_plan.rs
+++ b/src/physical_plan.rs
@@ -78,8 +78,7 @@ impl PyExecutionPlan {
         let proto_plan =
             datafusion_proto::protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| {
                 PyRuntimeError::new_err(format!(
-                    "Unable to decode logical node from serialized bytes: {}",
-                    e
+                    "Unable to decode logical node from serialized bytes: {e}"
                 ))
             })?;
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 198d68bdc..97d320470 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -201,8 +201,7 @@ impl PyLogicalPlan {
         let proto_plan =
             datafusion_proto::protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
                 PyRuntimeError::new_err(format!(
-                    "Unable to decode logical node from serialized bytes: {}",
-                    e
+                    "Unable to decode logical node from serialized bytes: {e}"
                 ))
             })?;
diff --git a/src/utils.rs b/src/utils.rs
index f4e121fd5..3b30de5de 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -109,8 +109,7 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
     let capsule_name = capsule_name.unwrap().to_str()?;
     if capsule_name != name {
         return Err(PyValueError::new_err(format!(
-            "Expected name '{}' in PyCapsule, instead got '{}'",
-            name, capsule_name
+            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
         )));
     }

@@ -127,7 +126,7 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<