Skip to content

Commit 05f5356

Browse files
committed
Exposing FFI to python
1 parent 278a33e commit 05f5356

File tree

7 files changed

+297
-0
lines changed

7 files changed

+297
-0
lines changed

examples/datafusion-ffi-example/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/datafusion-ffi-example/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"]
2727
arrow = { version = "55.0.0" }
2828
arrow-array = { version = "55.0.0" }
2929
arrow-schema = { version = "55.0.0" }
30+
async-trait = "0.1.88"
3031

3132
[build-dependencies]
3233
pyo3-build-config = "0.23"
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pyarrow as pa
21+
22+
from datafusion import SessionContext
23+
from datafusion_ffi_example import MyCatalogProvider
24+
25+
def test_catalog_provider():
    """Register the FFI catalog provider and read its table back through DataFusion."""
    catalog_name = "my_catalog"
    schema_name = "my_schema"
    table_name = "my_table"
    column_names = ['units', 'price']

    ctx = SessionContext()
    ctx.register_catalog_provider(catalog_name, MyCatalogProvider())

    # Walk catalog -> schema -> table and check each level is visible.
    catalog = ctx.catalog(catalog_name)
    assert schema_name in catalog.names()

    database = catalog.database(schema_name)
    assert table_name in database.names()

    table = database.table(table_name)
    assert column_names == table.schema.names

    # The table must also be usable once registered directly on the context.
    ctx.register_table(table_name, table)
    frame = ctx.sql(f"SELECT * FROM {table_name}").to_pandas()
    assert len(frame) == 5
    assert column_names == frame.columns.tolist()

    # Reading through the fully qualified name yields the two record batches.
    batches = ctx.table(f"{catalog_name}.{schema_name}.{table_name}").collect()
    assert len(batches) == 2

    units = [batch.column(0) for batch in batches]
    prices = [batch.column(1) for batch in batches]

    assert units == [
        pa.array([10, 20, 30], type=pa.int32()),
        pa.array([5, 7], type=pa.int32()),
    ]
    assert prices == [
        pa.array([1, 2, 5], type=pa.float64()),
        pa.array([1.5, 2.5], type=pa.float64()),
    ]
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::{any::Any, fmt::Debug, sync::Arc};
19+
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
20+
21+
use arrow::datatypes::Schema;
22+
use async_trait::async_trait;
23+
use datafusion::{
24+
catalog::{
25+
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
26+
TableProvider,
27+
},
28+
common::exec_err,
29+
datasource::MemTable,
30+
error::{DataFusionError, Result},
31+
};
32+
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
33+
use pyo3::types::PyCapsule;
34+
35+
pub fn my_table() -> Arc<dyn TableProvider + 'static> {
36+
use arrow::datatypes::{DataType, Field};
37+
use datafusion::common::record_batch;
38+
39+
let schema = Arc::new(Schema::new(vec![
40+
Field::new("units", DataType::Int32, true),
41+
Field::new("price", DataType::Float64, true),
42+
]));
43+
44+
let partitions = vec![
45+
record_batch!(
46+
("units", Int32, vec![10, 20, 30]),
47+
("price", Float64, vec![1.0, 2.0, 5.0])
48+
)
49+
.unwrap(),
50+
record_batch!(
51+
("units", Int32, vec![5, 7]),
52+
("price", Float64, vec![1.5, 2.5])
53+
)
54+
.unwrap(),
55+
];
56+
57+
Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
58+
}
59+
60+
#[derive(Debug)]
61+
pub struct FixedSchemaProvider {
62+
inner: MemorySchemaProvider,
63+
}
64+
65+
impl Default for FixedSchemaProvider {
66+
fn default() -> Self {
67+
let inner = MemorySchemaProvider::new();
68+
69+
let table = my_table();
70+
71+
let _ = inner
72+
.register_table("my_table".to_string(), table)
73+
.unwrap();
74+
75+
Self { inner }
76+
}
77+
}
78+
79+
#[async_trait]
80+
impl SchemaProvider for FixedSchemaProvider {
81+
fn as_any(&self) -> &dyn Any {
82+
self
83+
}
84+
85+
fn table_names(&self) -> Vec<String> {
86+
self.inner.table_names()
87+
}
88+
89+
async fn table(
90+
&self,
91+
name: &str,
92+
) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
93+
self.inner.table(name).await
94+
}
95+
96+
fn register_table(
97+
&self,
98+
name: String,
99+
table: Arc<dyn TableProvider>,
100+
) -> Result<Option<Arc<dyn TableProvider>>> {
101+
self.inner.register_table(name, table)
102+
}
103+
104+
fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
105+
self.inner.deregister_table(name)
106+
}
107+
108+
fn table_exist(&self, name: &str) -> bool {
109+
self.inner.table_exist(name)
110+
}
111+
}
112+
113+
114+
/// This catalog provider is intended only for unit tests. It prepopulates with one
115+
/// schema and only allows for schemas named after four types of fruit.
116+
#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)]
117+
#[derive(Debug)]
118+
pub(crate) struct MyCatalogProvider {
119+
inner: MemoryCatalogProvider,
120+
}
121+
122+
impl Default for MyCatalogProvider {
123+
fn default() -> Self {
124+
let inner = MemoryCatalogProvider::new();
125+
126+
let schema_name: &str = "my_schema";
127+
let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default()));
128+
129+
Self { inner }
130+
}
131+
}
132+
133+
impl CatalogProvider for MyCatalogProvider {
134+
fn as_any(&self) -> &dyn Any {
135+
self
136+
}
137+
138+
fn schema_names(&self) -> Vec<String> {
139+
self.inner.schema_names()
140+
}
141+
142+
fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
143+
self.inner.schema(name)
144+
}
145+
146+
fn register_schema(
147+
&self,
148+
name: &str,
149+
schema: Arc<dyn SchemaProvider>,
150+
) -> Result<Option<Arc<dyn SchemaProvider>>> {
151+
self.inner.register_schema(name, schema)
152+
}
153+
154+
fn deregister_schema(
155+
&self,
156+
name: &str,
157+
cascade: bool,
158+
) -> Result<Option<Arc<dyn SchemaProvider>>> {
159+
self.inner.deregister_schema(name, cascade)
160+
}
161+
}
162+
163+
#[pymethods]
164+
impl MyCatalogProvider {
165+
#[new]
166+
pub fn new() -> Self {
167+
Self {
168+
inner: Default::default(),
169+
}
170+
}
171+
172+
pub fn __datafusion_catalog_provider__<'py>(
173+
&self,
174+
py: Python<'py>,
175+
) -> PyResult<Bound<'py, PyCapsule>> {
176+
let name = cr"datafusion_catalog_provider".into();
177+
let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);
178+
179+
PyCapsule::new(py, catalog_provider, Some(name))
180+
}
181+
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717

1818
use crate::table_function::MyTableFunction;
1919
use crate::table_provider::MyTableProvider;
20+
use crate::catalog_provider::MyCatalogProvider;
2021
use pyo3::prelude::*;
2122

2223
pub(crate) mod table_function;
2324
pub(crate) mod table_provider;
25+
pub(crate) mod catalog_provider;
2426

2527
#[pymodule]
2628
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
2729
m.add_class::<MyTableProvider>()?;
2830
m.add_class::<MyTableFunction>()?;
31+
m.add_class::<MyCatalogProvider>()?;
2932
Ok(())
3033
}

python/datafusion/context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ class TableProviderExportable(Protocol):
7878
def __datafusion_table_provider__(self) -> object: ... # noqa: D105
7979

8080

81+
class CatalogProviderExportable(Protocol):
    """Type hint for object that has __datafusion_catalog_provider__ PyCapsule.

    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
    """

    # Same lint suppression as the table provider protocol: the stub is
    # documented by the class docstring, not by a per-method docstring.
    def __datafusion_catalog_provider__(self) -> object: ...  # noqa: D105
87+
88+
8189
class SessionConfig:
8290
"""Session configuration options."""
8391

@@ -742,6 +750,12 @@ def deregister_table(self, name: str) -> None:
742750
"""Remove a table from the session."""
743751
self.ctx.deregister_table(name)
744752

753+
def register_catalog_provider(
    self, name: str, provider: CatalogProviderExportable
) -> None:
    """Register a catalog provider.

    Args:
        name: Name under which the catalog is registered.
        provider: Object exposing ``__datafusion_catalog_provider__``.
    """
    internal_ctx = self.ctx
    internal_ctx.register_catalog_provider(name, provider)
758+
745759
def register_table_provider(
746760
self, name: str, provider: TableProviderExportable
747761
) -> None:

src/context.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f
4949
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5050
use datafusion::arrow::pyarrow::PyArrowType;
5151
use datafusion::arrow::record_batch::RecordBatch;
52+
use datafusion::catalog::CatalogProvider;
5253
use datafusion::common::TableReference;
5354
use datafusion::common::{exec_err, ScalarValue};
5455
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -70,6 +71,7 @@ use datafusion::prelude::{
7071
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7172
};
7273
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
74+
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
7375
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7476
use tokio::task::JoinHandle;
7577

@@ -582,6 +584,38 @@ impl PySessionContext {
582584
Ok(())
583585
}
584586

587+
pub fn register_catalog_provider(
588+
&mut self,
589+
name: &str,
590+
provider: Bound<'_, PyAny>,
591+
) -> PyDataFusionResult<()> {
592+
if provider.hasattr("__datafusion_catalog_provider__")? {
593+
let capsule = provider.getattr("__datafusion_catalog_provider__")?.call0()?;
594+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
595+
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
596+
597+
let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
598+
let provider: ForeignCatalogProvider = provider.into();
599+
600+
let option: Option<Arc<dyn CatalogProvider>> = self.ctx.register_catalog(name, Arc::new(provider));
601+
match option {
602+
Some(existing) => {
603+
println!("Catalog '{}' already existed, schema names: {:?}", name, existing.schema_names());
604+
}
605+
None => {
606+
println!("Catalog '{}' registered successfully", name);
607+
}
608+
}
609+
610+
Ok(())
611+
} else {
612+
Err(crate::errors::PyDataFusionError::Common(
613+
"__datafusion_catalog_provider__ does not exist on Catalog Provider object."
614+
.to_string(),
615+
))
616+
}
617+
}
618+
585619
/// Construct datafusion dataframe from Arrow Table
586620
pub fn register_table_provider(
587621
&mut self,

0 commit comments

Comments
 (0)