Skip to content

feat: upgrade df48 dependency #1143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 100 additions & 84 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ substrait = ["dep:datafusion-substrait"]
tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]}
arrow = { version = "55.0.0", features = ["pyarrow"] }
datafusion = { version = "47.0.0", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "47.0.0", optional = true }
datafusion-proto = { version = "47.0.0" }
datafusion-ffi = { version = "47.0.0" }
arrow = { version = "55.1.0", features = ["pyarrow"] }
datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "48.0.0", optional = true }
datafusion-proto = { version = "48.0.0" }
datafusion-ffi = { version = "48.0.0" }
prost = "0.13.1" # keep in line with `datafusion-substrait`
uuid = { version = "1.16", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
Expand Down
19 changes: 19 additions & 0 deletions python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
See https://datafusion.apache.org/python for more information.
"""

from __future__ import annotations

from typing import Any

try:
import importlib.metadata as importlib_metadata
except ImportError:
Expand Down Expand Up @@ -130,3 +134,18 @@ def str_lit(value):
def lit(value) -> Expr:
"""Create a literal expression."""
return Expr.literal(value)


def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
"""Creates a new expression representing a scalar value with metadata.

Args:
value: A valid PyArrow scalar value or easily castable to one.
metadata: Metadata to attach to the expression.
"""
return Expr.literal_with_metadata(value, metadata)


def lit_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
"""Alias for literal_with_metadata."""
return literal_with_metadata(value, metadata)
12 changes: 12 additions & 0 deletions python/datafusion/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __init__(self, catalog: df_internal.Catalog) -> None:
"""This constructor is not typically called by the end user."""
self.catalog = catalog

def __repr__(self) -> str:
"""Print a string representation of the catalog."""
return self.catalog.__repr__()

def names(self) -> list[str]:
"""Returns the list of databases in this catalog."""
return self.catalog.names()
Expand All @@ -50,6 +54,10 @@ def __init__(self, db: df_internal.Database) -> None:
"""This constructor is not typically called by the end user."""
self.db = db

def __repr__(self) -> str:
"""Print a string representation of the database."""
return self.db.__repr__()

def names(self) -> set[str]:
"""Returns the list of all tables in this database."""
return self.db.names()
Expand All @@ -66,6 +74,10 @@ def __init__(self, table: df_internal.Table) -> None:
"""This constructor is not typically called by the end user."""
self.table = table

def __repr__(self) -> str:
"""Print a string representation of the table."""
return self.table.__repr__()

@property
def schema(self) -> pa.Schema:
"""Returns the schema associated with this table."""
Expand Down
4 changes: 4 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ def __init__(

self.ctx = SessionContextInternal(config, runtime)

def __repr__(self) -> str:
"""Print a string representation of the Session Context."""
return self.ctx.__repr__()

@classmethod
def global_ctx(cls) -> SessionContext:
"""Retrieve the global context as a `SessionContext` wrapper.
Expand Down
18 changes: 18 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,20 @@ def literal(value: Any) -> Expr:
value = pa.scalar(value)
return Expr(expr_internal.RawExpr.literal(value))

@staticmethod
def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
"""Creates a new expression representing a scalar value with metadata.
Args:
value: A valid PyArrow scalar value or easily castable to one.
metadata: Metadata to attach to the expression.
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
value = value if isinstance(value, pa.Scalar) else pa.scalar(value)

return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))

@staticmethod
def string_literal(value: str) -> Expr:
"""Creates a new expression representing a UTF8 literal value.
Expand Down Expand Up @@ -1172,6 +1186,10 @@ def __init__(
end_bound = end_bound.cast(pa.uint64())
self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound)

def __repr__(self) -> str:
"""Print a string representation of the window frame."""
return self.window_frame.__repr__()

def get_frame_units(self) -> str:
"""Returns the window frame units for the bounds."""
return self.window_frame.get_frame_units()
Expand Down
12 changes: 12 additions & 0 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def __init__(
name, func, input_types, return_type, str(volatility)
)

def __repr__(self) -> str:
"""Print a string representation of the Scalar UDF."""
return self._udf.__repr__()

def __call__(self, *args: Expr) -> Expr:
"""Execute the UDF.

Expand Down Expand Up @@ -268,6 +272,10 @@ def __init__(
str(volatility),
)

def __repr__(self) -> str:
"""Print a string representation of the Aggregate UDF."""
return self._udaf.__repr__()

def __call__(self, *args: Expr) -> Expr:
"""Execute the UDAF.

Expand Down Expand Up @@ -604,6 +612,10 @@ def __init__(
name, func, input_types, return_type, str(volatility)
)

def __repr__(self) -> str:
"""Print a string representation of the Window UDF."""
return self._udwf.__repr__()

def __call__(self, *args: Expr) -> Expr:
"""Execute the UDWF.

Expand Down
60 changes: 58 additions & 2 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@

import pyarrow as pa
import pytest
from datafusion import SessionContext, col, functions, lit
from datafusion import (
SessionContext,
col,
functions,
lit,
lit_with_metadata,
literal_with_metadata,
)
from datafusion.expr import (
Aggregate,
AggregateFunction,
Expand Down Expand Up @@ -103,7 +110,7 @@ def test_limit(test_ctx):

plan = plan.to_variant()
assert isinstance(plan, Limit)
assert "Skip: Some(Literal(Int64(5)))" in str(plan)
assert "Skip: Some(Literal(Int64(5), None))" in str(plan)


def test_aggregate_query(test_ctx):
Expand Down Expand Up @@ -824,3 +831,52 @@ def test_expr_functions(ctx, function, expected_result):

assert len(result) == 1
assert result[0].column(0).equals(expected_result)


def test_literal_metadata(ctx):
result = (
ctx.from_pydict({"a": [1]})
.select(
lit(1).alias("no_metadata"),
lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"),
literal_with_metadata(3, {"key2": "value2"}).alias(
"literal_with_metadata_fn"
),
)
.collect()
)

expected_schema = pa.schema(
[
pa.field("no_metadata", pa.int64(), nullable=False),
pa.field(
"lit_with_metadata_fn",
pa.int64(),
nullable=False,
metadata={"key1": "value1"},
),
pa.field(
"literal_with_metadata_fn",
pa.int64(),
nullable=False,
metadata={"key2": "value2"},
),
]
)

expected = pa.RecordBatch.from_pydict(
{
"no_metadata": pa.array([1]),
"lit_with_metadata_fn": pa.array([2]),
"literal_with_metadata_fn": pa.array([3]),
},
schema=expected_schema,
)

assert result[0] == expected

# Testing result[0].schema == expected_schema does not check each key/value pair
# so we want to explicitly test these
for expected_field in expected_schema:
actual_field = result[0].schema.field(expected_field.name)
assert expected_field.metadata == actual_field.metadata
7 changes: 5 additions & 2 deletions python/tests/test_wrapper_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
from enum import EnumMeta as EnumType


def missing_exports(internal_obj, wrapped_obj) -> None:
def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901
"""
Identify if any of the rust exposted structs or functions do not have wrappers.

Special handling for:
- Raw* classes: Internal implementation details that shouldn't be exposed
- _global_ctx: Internal implementation detail
- __self__, __class__: Python special attributes
- __self__, __class__, __repr__: Python special attributes
"""
# Special case enums - EnumType overrides a some of the internal functions,
# so check all of the values exist and move on
Expand All @@ -45,6 +45,9 @@ def missing_exports(internal_obj, wrapped_obj) -> None:
assert value in dir(wrapped_obj)
return

if "__repr__" in internal_obj.__dict__ and "__repr__" not in wrapped_obj.__dict__:
pytest.fail(f"Missing __repr__: {internal_obj.__name__}")

for internal_attr_name in dir(internal_obj):
wrapped_attr_name = internal_attr_name.removeprefix("Raw")
assert wrapped_attr_name in dir(wrapped_obj)
Expand Down
47 changes: 37 additions & 10 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::{
DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
Expand Down Expand Up @@ -183,22 +183,49 @@ impl PyRuntimeEnvBuilder {
}

fn with_disk_manager_disabled(&self) -> Self {
let mut builder = self.builder.clone();
builder = builder.with_disk_manager(DiskManagerConfig::Disabled);
Self { builder }
let mut runtime_builder = self.builder.clone();

let mut disk_mgr_builder = runtime_builder
.disk_manager_builder
.clone()
.unwrap_or_default();
disk_mgr_builder.set_mode(DiskManagerMode::Disabled);

runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
Self {
builder: runtime_builder,
}
}

fn with_disk_manager_os(&self) -> Self {
let builder = self.builder.clone();
let builder = builder.with_disk_manager(DiskManagerConfig::NewOs);
Self { builder }
let mut runtime_builder = self.builder.clone();

let mut disk_mgr_builder = runtime_builder
.disk_manager_builder
.clone()
.unwrap_or_default();
disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);

runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
Self {
builder: runtime_builder,
}
}

fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
let builder = self.builder.clone();
let paths = paths.iter().map(|s| s.into()).collect();
let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths));
Self { builder }
let mut runtime_builder = self.builder.clone();

let mut disk_mgr_builder = runtime_builder
.disk_manager_builder
.clone()
.unwrap_or_default();
disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));

runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
Self {
builder: runtime_builder,
}
}

fn with_unbounded_memory_pool(&self) -> Self {
Expand Down
Loading