Skip to content

Commit c6d406a

Browse files
committed
Workin progress on python catalog
1 parent 7b09073 commit c6d406a

File tree

2 files changed

+163
-15
lines changed

2 files changed

+163
-15
lines changed

src/catalog.rs

Lines changed: 152 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,23 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::collections::HashSet;
19-
use std::sync::Arc;
20-
21-
use pyo3::exceptions::PyKeyError;
22-
use pyo3::prelude::*;
23-
24-
use crate::errors::{PyDataFusionError, PyDataFusionResult};
25-
use crate::utils::wait_for_future;
18+
use crate::dataset::Dataset;
19+
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20+
use crate::utils::{validate_pycapsule, wait_for_future};
21+
use async_trait::async_trait;
22+
use datafusion::common::DataFusionError;
2623
use datafusion::{
2724
arrow::pyarrow::ToPyArrow,
2825
catalog::{CatalogProvider, SchemaProvider},
2926
datasource::{TableProvider, TableType},
3027
};
28+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
29+
use pyo3::exceptions::PyKeyError;
30+
use pyo3::prelude::*;
31+
use pyo3::types::PyCapsule;
32+
use std::any::Any;
33+
use std::collections::HashSet;
34+
use std::sync::Arc;
3135

3236
#[pyclass(name = "Catalog", module = "datafusion", subclass)]
3337
pub struct PyCatalog {
@@ -50,8 +54,8 @@ impl PyCatalog {
5054
}
5155
}
5256

53-
impl PyDatabase {
54-
pub fn new(database: Arc<dyn SchemaProvider>) -> Self {
57+
impl From<Arc<dyn SchemaProvider>> for PyDatabase {
58+
fn from(database: Arc<dyn SchemaProvider>) -> Self {
5559
Self { database }
5660
}
5761
}
@@ -75,7 +79,7 @@ impl PyCatalog {
7579
#[pyo3(signature = (name="public"))]
7680
fn database(&self, name: &str) -> PyResult<PyDatabase> {
7781
match self.catalog.schema(name) {
78-
Some(database) => Ok(PyDatabase::new(database)),
82+
Some(database) => Ok(database.into()),
7983
None => Err(PyKeyError::new_err(format!(
8084
"Database with name {name} doesn't exist."
8185
))),
@@ -92,6 +96,13 @@ impl PyCatalog {
9296

9397
#[pymethods]
9498
impl PyDatabase {
99+
#[new]
100+
fn new(schema_provider: PyObject) -> Self {
101+
let schema_provider =
102+
Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc<dyn SchemaProvider>;
103+
schema_provider.into()
104+
}
105+
95106
fn names(&self) -> HashSet<String> {
96107
self.database.table_names().into_iter().collect()
97108
}
@@ -145,3 +156,133 @@ impl PyTable {
145156
// fn has_exact_statistics
146157
// fn supports_filter_pushdown
147158
}
159+
160+
#[derive(Debug)]
161+
struct RustWrappedPySchemaProvider {
162+
schema_provider: PyObject,
163+
owner_name: Option<String>,
164+
}
165+
166+
impl RustWrappedPySchemaProvider {
167+
fn new(schema_provider: PyObject) -> Self {
168+
let owner_name = Python::with_gil(|py| {
169+
schema_provider
170+
.bind(py)
171+
.getattr("owner_name")
172+
.ok()
173+
.map(|name| name.to_string())
174+
});
175+
176+
Self {
177+
schema_provider,
178+
owner_name,
179+
}
180+
}
181+
182+
fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
183+
Python::with_gil(|py| {
184+
let provider = self.schema_provider.bind(py);
185+
let py_table_method = provider.getattr("table")?;
186+
187+
let py_table = py_table_method.call((name,), None)?;
188+
if py_table.is_none() {
189+
return Ok(None);
190+
}
191+
192+
if py_table.hasattr("__datafusion_table_provider__")? {
193+
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
194+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
195+
validate_pycapsule(capsule, "datafusion_table_provider")?;
196+
197+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
198+
let provider: ForeignTableProvider = provider.into();
199+
200+
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
201+
} else {
202+
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
203+
204+
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
205+
}
206+
})
207+
}
208+
}
209+
210+
#[async_trait]
211+
impl SchemaProvider for RustWrappedPySchemaProvider {
212+
fn owner_name(&self) -> Option<&str> {
213+
self.owner_name.as_deref()
214+
}
215+
216+
fn as_any(&self) -> &dyn Any {
217+
self
218+
}
219+
220+
fn table_names(&self) -> Vec<String> {
221+
Python::with_gil(|py| {
222+
let provider = self.schema_provider.bind(py);
223+
provider
224+
.getattr("table_names")
225+
.and_then(|names| names.extract::<Vec<String>>())
226+
.unwrap_or_default()
227+
})
228+
}
229+
230+
async fn table(
231+
&self,
232+
name: &str,
233+
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
234+
self.table_inner(name).map_err(to_datafusion_err)
235+
}
236+
237+
fn register_table(
238+
&self,
239+
name: String,
240+
table: Arc<dyn TableProvider>,
241+
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
242+
let py_table = PyTable::new(table);
243+
Python::with_gil(|py| {
244+
let provider = self.schema_provider.bind(py);
245+
let _ = provider
246+
.call_method1("register_table", (name, py_table))
247+
.map_err(to_datafusion_err)?;
248+
// Since the definition of `register_table` says that an error
249+
// will be returned if the table already exists, there is no
250+
// case where we want to return a table provider as output.
251+
Ok(None)
252+
})
253+
}
254+
255+
fn deregister_table(
256+
&self,
257+
name: &str,
258+
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
259+
Python::with_gil(|py| {
260+
let provider = self.schema_provider.bind(py);
261+
let table = provider
262+
.call_method1("deregister_table", (name,))
263+
.map_err(to_datafusion_err)?;
264+
if table.is_none() {
265+
return Ok(None);
266+
}
267+
268+
// If we can turn this table provider into a `Dataset`, return it.
269+
// Otherwise, return None.
270+
let dataset = match Dataset::new(&table, py) {
271+
Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
272+
Err(_) => None,
273+
};
274+
275+
Ok(dataset)
276+
})
277+
}
278+
279+
fn table_exist(&self, name: &str) -> bool {
280+
Python::with_gil(|py| {
281+
let provider = self.schema_provider.bind(py);
282+
provider
283+
.call_method1("table_exist", (name,))
284+
.and_then(|pyobj| pyobj.extract())
285+
.unwrap_or(false)
286+
})
287+
}
288+
}

src/context.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ use datafusion::physical_plan::SendableRecordBatchStream;
7070
use datafusion::prelude::{
7171
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7272
};
73-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7473
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
74+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7575
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7676
use tokio::task::JoinHandle;
7777

@@ -590,17 +590,24 @@ impl PySessionContext {
590590
provider: Bound<'_, PyAny>,
591591
) -> PyDataFusionResult<()> {
592592
if provider.hasattr("__datafusion_catalog_provider__")? {
593-
let capsule = provider.getattr("__datafusion_catalog_provider__")?.call0()?;
593+
let capsule = provider
594+
.getattr("__datafusion_catalog_provider__")?
595+
.call0()?;
594596
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
595597
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
596598

597599
let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
598600
let provider: ForeignCatalogProvider = provider.into();
599601

600-
let option: Option<Arc<dyn CatalogProvider>> = self.ctx.register_catalog(name, Arc::new(provider));
602+
let option: Option<Arc<dyn CatalogProvider>> =
603+
self.ctx.register_catalog(name, Arc::new(provider));
601604
match option {
602605
Some(existing) => {
603-
println!("Catalog '{}' already existed, schema names: {:?}", name, existing.schema_names());
606+
println!(
607+
"Catalog '{}' already existed, schema names: {:?}",
608+
name,
609+
existing.schema_names()
610+
);
604611
}
605612
None => {
606613
println!("Catalog '{}' registered successfully", name);

0 commit comments

Comments
 (0)