diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py
index df7e1813b..abc324db8 100644
--- a/python/datafusion/tests/test_context.py
+++ b/python/datafusion/tests/test_context.py
@@ -16,6 +16,7 @@
 # under the License.
 import gzip
 import os
+import datetime as dt
 
 import pyarrow as pa
 import pyarrow.dataset as ds
@@ -303,6 +304,59 @@ def test_dataset_filter(ctx, capfd):
     assert result[0].column(1) == pa.array([-3])
 
 
+def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
+    """Ensure that the pyarrow filter gets pushed down for `IsNull`."""
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, None, 9])],
+        names=["a", "b", "c"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+    # make sure the filter was pushed down into the physical plan
+    df = ctx.sql("SELECT a FROM t WHERE c is NULL")
+    df.explain()
+    captured = capfd.readouterr()
+    assert "filter_expr=is_null(c, {nan_is_null=false})" in captured.out
+
+    result = df.collect()
+    assert result[0].column(0) == pa.array([2])
+
+
+def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
+    """Ensure that the pyarrow filter gets pushed down for timestamps."""
+    # Ref: https://github.com/apache/datafusion-python/issues/703
+
+    # create a pyarrow dataset with no actual files
+    col_type = pa.timestamp("ns", "+00:00")
+    nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type)
+    pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
+    pa_dataset_format = pa.dataset.ParquetFileFormat()
+    pa_dataset_partition = pa.dataset.field("a") <= nyd_2000
+    fragments = [
+        # NOTE: we never actually create this file.
+        # Working predicate pushdown means it never gets accessed.
+        pa_dataset_format.make_fragment(
+            "1.parquet",
+            filesystem=pa_dataset_fs,
+            partition_expression=pa_dataset_partition,
+        )
+    ]
+    pa_dataset = pa.dataset.FileSystemDataset(
+        fragments,
+        pa.schema([pa.field("a", col_type)]),
+        pa_dataset_format,
+        pa_dataset_fs,
+    )
+
+    ctx.register_dataset("t", pa_dataset)
+
+    # the partition for our only fragment is a <= 2000-01-01,
+    # so querying for a > 2024-01-01 should not touch any files
+    df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
+    assert df.collect() == []
+
+
 def test_dataset_filter_nested_data(ctx):
     # create Arrow StructArrays to test nested data types
     data = pa.StructArray.from_arrays(
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index fca885121..ff447e1ab 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -21,6 +21,7 @@
 use pyo3::prelude::*;
 use std::convert::TryFrom;
 use std::result::Result;
 
+use arrow::pyarrow::ToPyArrow;
 use datafusion_common::{Column, ScalarValue};
 use datafusion_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
@@ -56,6 +57,7 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, Data
     let ret: Result<Vec<PyObject>, DataFusionError> = exprs
         .iter()
         .map(|expr| match expr {
+            // TODO: should we also leverage `ScalarValue::to_pyarrow` here?
             Expr::Literal(v) => match v {
                 ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
                 ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
@@ -100,23 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
             let op_module = Python::import_bound(py, "operator")?;
             let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
                 Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
-                Expr::Literal(v) => match v {
-                    ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
-                    ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
-                    ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
-                    ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
-                    _ => Err(DataFusionError::Common(format!(
-                        "PyArrow can't handle ScalarValue: {v:?}"
-                    ))),
-                },
+                Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
                 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                     let operator = operator_to_py(op, &op_module)?;
                     let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
@@ -138,8 +124,13 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
                     let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                         .0
                         .into_bound(py);
-                    // TODO: this expression does not seems like it should be `call_method0`
-                    Ok(expr.clone().call_method1("is_null", (expr,))?)
+
+                    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression.is_null
+                    // Whether floating-point NaNs are considered null.
+                    let nan_is_null = false;
+
+                    let res = expr.call_method1("is_null", (nan_is_null,))?;
+                    Ok(res)
                 }
                 Expr::Between(Between {
                     expr,
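For reviewers who want to poke at this locally: below is a minimal standalone sketch of the `IsNull` pushdown the first test asserts. It uses the same public API as the test (`SessionContext.register_dataset`, `DataFrame.explain`); the table name and sample data are arbitrary, and it assumes a `datafusion` build that includes this patch.

```python
import pyarrow as pa
import pyarrow.dataset as ds
from datafusion import SessionContext

ctx = SessionContext()

# Registering a pyarrow *Dataset* (rather than an in-memory table) is what
# routes SQL predicates through src/pyarrow_filter_expression.rs, where they
# are rewritten as pyarrow compute expressions and pushed down to the scan.
batch = pa.RecordBatch.from_arrays([pa.array([7, None, 9])], names=["c"])
ctx.register_dataset("t", ds.dataset([batch]))

# The printed physical plan should contain the pushed-down filter, e.g.
#   filter_expr=is_null(c, {nan_is_null=false})
ctx.sql("SELECT c FROM t WHERE c IS NULL").explain()
```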