
Commit 9b6acec

Support types other than String and Int for partition columns (#1154)

miclegrtrade6-bot and michele gregori authored

* impl impl
* fix test
* format rust
* support for old logic dasdas
* also on io
* fix formatting

Co-authored-by: michele gregori <[email protected]>

1 parent 1391078 · commit 9b6acec
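
The user-facing effect of this change: table_partition_cols now accepts pyarrow data types in addition to the legacy "string"/"int" literals. A minimal usage sketch, not taken from the commit (the table name and dataset path below are hypothetical):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Partition columns can now be declared with arbitrary pyarrow types,
# e.g. a date64 partition alongside a string partition.
ctx.register_parquet(
    "events",        # hypothetical table name
    "/data/events/", # hypothetical hive-partitioned directory
    table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
)

# The old literals still work but now emit a DeprecationWarning:
# table_partition_cols=[("grp", "string")]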

File tree

4 files changed: +132 -57 lines


python/datafusion/context.py

Lines changed: 57 additions & 9 deletions
@@ -19,8 +19,11 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Protocol
 
+import pyarrow as pa
+
 try:
     from warnings import deprecated  # Python 3.13+
 except ImportError:
@@ -42,7 +45,6 @@
 
     import pandas as pd
     import polars as pl
-    import pyarrow as pa
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
@@ -539,7 +541,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -560,6 +562,7 @@ def register_listing_table(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order_raw = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -778,7 +781,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -806,6 +809,7 @@ def register_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_parquet(
             name,
             str(path),
@@ -869,7 +873,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -890,6 +894,7 @@ def register_json(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_json(
             name,
             str(path),
@@ -906,7 +911,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
         """Register an Avro file as a table.
@@ -922,6 +927,7 @@ def register_avro(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )
@@ -981,7 +987,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1001,6 +1007,7 @@ def read_json(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         return DataFrame(
             self.ctx.read_json(
                 str(path),
@@ -1020,7 +1027,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1045,6 +1052,7 @@ def read_csv(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
 
         path = [str(p) for p in path] if isinstance(path, list) else str(path)
@@ -1064,7 +1072,7 @@ def read_csv(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -1093,6 +1101,7 @@ def read_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -1114,7 +1123,7 @@ def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str]] | None = None,
+        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1130,6 +1139,7 @@ def read_avro(
         """
         if file_partition_cols is None:
             file_partition_cols = []
+        file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
         return DataFrame(
             self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
         )
@@ -1146,3 +1156,41 @@ def read_table(self, table: Table) -> DataFrame:
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    @staticmethod
+    def _convert_table_partition_cols(
+        table_partition_cols: list[tuple[str, str | pa.DataType]],
+    ) -> list[tuple[str, pa.DataType]]:
+        warn = False
+        converted_table_partition_cols = []
+
+        for col, data_type in table_partition_cols:
+            if isinstance(data_type, str):
+                warn = True
+                if data_type == "string":
+                    converted_data_type = pa.string()
+                elif data_type == "int":
+                    converted_data_type = pa.int32()
+                else:
+                    message = (
+                        f"Unsupported literal data type '{data_type}' for partition "
+                        "column. Supported types are 'string' and 'int'"
+                    )
+                    raise ValueError(message)
+            else:
+                converted_data_type = data_type
+
+            converted_table_partition_cols.append((col, converted_data_type))
+
+        if warn:
+            message = (
+                "using literals for table_partition_cols data types is deprecated,"
+                "use pyarrow types instead"
+            )
+            warnings.warn(
+                message,
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return converted_table_partition_cols
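
A short sketch of how the new private helper behaves, assuming it is reachable as SessionContext._convert_table_partition_cols as added above (calling the staticmethod directly is for illustration only; in normal use it is invoked by the register_*/read_* methods):

import warnings

import pyarrow as pa
from datafusion import SessionContext

# pyarrow types pass through unchanged.
cols = SessionContext._convert_table_partition_cols([("grp", pa.string())])
assert cols == [("grp", pa.string())]

# Legacy literals are still accepted: "string" maps to pa.string(),
# "int" maps to pa.int32(), and a DeprecationWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cols = SessionContext._convert_table_partition_cols([("date_id", "int")])
assert cols == [("date_id", pa.int32())]
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Any other literal, e.g. "float", raises ValueError.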

python/datafusion/io.py

Lines changed: 4 additions & 4 deletions
@@ -34,7 +34,7 @@
 
 def read_parquet(
     path: str | pathlib.Path,
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     parquet_pruning: bool = True,
     file_extension: str = ".parquet",
     skip_metadata: bool = True,
@@ -83,7 +83,7 @@ def read_json(
     schema: pa.Schema | None = None,
     schema_infer_max_records: int = 1000,
     file_extension: str = ".json",
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     file_compression_type: str | None = None,
 ) -> DataFrame:
     """Read a line-delimited JSON data source.
@@ -124,7 +124,7 @@ def read_csv(
     delimiter: str = ",",
     schema_infer_max_records: int = 1000,
     file_extension: str = ".csv",
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     file_compression_type: str | None = None,
 ) -> DataFrame:
     """Read a CSV data source.
@@ -171,7 +171,7 @@ def read_csv(
 def read_avro(
     path: str | pathlib.Path,
     schema: pa.Schema | None = None,
-    file_partition_cols: list[tuple[str, str]] | None = None,
+    file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     file_extension: str = ".avro",
 ) -> DataFrame:
     """Create a :py:class:`DataFrame` for reading Avro data source.

python/tests/test_sql.py

Lines changed: 15 additions & 11 deletions
@@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}
 
 
-@pytest.mark.parametrize("path_to_str", [True, False])
-def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
+@pytest.mark.parametrize(
+    ("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)]
+)
+def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type):
     dir_root = tmp_path / "dataset_parquet_partitioned"
     dir_root.mkdir(exist_ok=False)
     (dir_root / "grp=a").mkdir(exist_ok=False)
@@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
 
     dir_root = str(dir_root) if path_to_str else dir_root
 
+    partition_data_type = "string" if legacy_data_type else pa.string()
+
     ctx.register_parquet(
         "datapp",
         dir_root,
-        table_partition_cols=[("grp", "string")],
+        table_partition_cols=[("grp", partition_data_type)],
         parquet_pruning=True,
         file_extension=".parquet",
     )
@@ -488,9 +492,9 @@ def test_register_listing_table(
 ):
     dir_root = tmp_path / "dataset_parquet_partitioned"
     dir_root.mkdir(exist_ok=False)
-    (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
-    (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
-    (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True)
 
     table = pa.Table.from_arrays(
         [
@@ -501,21 +505,21 @@ def test_register_listing_table(
         names=["int", "str", "float"],
     )
     pa.parquet.write_table(
-        table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
+        table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet"
     )
     pa.parquet.write_table(
-        table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
+        table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet"
     )
     pa.parquet.write_table(
-        table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
+        table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet"
     )
 
     dir_root = f"file://{dir_root}/" if path_to_str else dir_root
 
     ctx.register_listing_table(
         "my_table",
         dir_root,
-        table_partition_cols=[("grp", "string"), ("date_id", "int")],
+        table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
         file_extension=".parquet",
         schema=table.schema if pass_schema else None,
         file_sort_order=file_sort_order,
@@ -531,7 +535,7 @@ def test_register_listing_table(
     assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}
 
     result = ctx.sql(
-        "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp"  # noqa: E501
+        "SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp"  # noqa: E501
     ).collect()
     result = pa.Table.from_batches(result)
