@@ -19,8 +19,11 @@

from __future__ import annotations

+import warnings
from typing import TYPE_CHECKING, Any, Protocol

+import pyarrow as pa
+
try:
    from warnings import deprecated  # Python 3.13+
except ImportError:
@@ -42,7 +45,6 @@

    import pandas as pd
    import polars as pl
-    import pyarrow as pa

    from datafusion.plan import ExecutionPlan, LogicalPlan

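Two things are happening in the imports. First, `import pyarrow as pa` moves out of the (collapsed) type-checking-only block into a real runtime import, since the new `_convert_table_partition_cols` helper below calls `pa.string()` and `pa.int32()` at runtime. Second, `warnings.deprecated` only exists on Python 3.13+ (PEP 702), so the try/except is the usual backport dance. The `except ImportError:` body is collapsed in this diff; a minimal sketch of the common pattern, assuming the `typing_extensions` backport (not confirmed by the visible lines):

try:
    from warnings import deprecated  # Python 3.13+ (PEP 702)
except ImportError:
    # Assumed fallback: typing_extensions ships the same decorator
    # for older interpreters.
    from typing_extensions import deprecated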
@@ -539,7 +541,7 @@ def register_listing_table(
        self,
        name: str,
        path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        file_extension: str = ".parquet",
        schema: pa.Schema | None = None,
        file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -560,6 +562,7 @@ def register_listing_table(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
        file_sort_order_raw = (
            [sort_list_to_raw_sort_list(f) for f in file_sort_order]
            if file_sort_order is not None
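Every signature in this PR changes the same way: the second element of each `table_partition_cols` tuple may now be a `pa.DataType` instead of a literal string, and the input is normalized through the new helper before being handed to the Rust side. A usage sketch of the widened API (table name and hive-style directory layout are made up for illustration):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# New style: real pyarrow types for partition columns, e.g. a layout
# like events/year=2024/month=05/part-0.parquet.
ctx.register_listing_table(
    "events",
    "events/",
    table_partition_cols=[("year", pa.int32()), ("month", pa.string())],
)

# Old style still works for now, but emits a DeprecationWarning:
ctx.register_listing_table(
    "events_legacy",
    "events/",
    table_partition_cols=[("year", "int"), ("month", "string")],
)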
@@ -778,7 +781,7 @@ def register_parquet(
        self,
        name: str,
        path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        parquet_pruning: bool = True,
        file_extension: str = ".parquet",
        skip_metadata: bool = True,
@@ -806,6 +809,7 @@ def register_parquet(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
        self.ctx.register_parquet(
            name,
            str(path),
@@ -869,7 +873,7 @@ def register_json(
        schema: pa.Schema | None = None,
        schema_infer_max_records: int = 1000,
        file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        file_compression_type: str | None = None,
    ) -> None:
        """Register a JSON file as a table.
@@ -890,6 +894,7 @@ def register_json(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
        self.ctx.register_json(
            name,
            str(path),
@@ -906,7 +911,7 @@ def register_avro(
        path: str | pathlib.Path,
        schema: pa.Schema | None = None,
        file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
    ) -> None:
        """Register an Avro file as a table.

@@ -922,6 +927,7 @@ def register_avro(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
        self.ctx.register_avro(
            name, str(path), schema, file_extension, table_partition_cols
        )
@@ -981,7 +987,7 @@ def read_json(
        schema: pa.Schema | None = None,
        schema_infer_max_records: int = 1000,
        file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        file_compression_type: str | None = None,
    ) -> DataFrame:
        """Read a line-delimited JSON data source.
@@ -1001,6 +1007,7 @@ def read_json(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
        return DataFrame(
            self.ctx.read_json(
                str(path),
@@ -1020,7 +1027,7 @@ def read_csv(
        delimiter: str = ",",
        schema_infer_max_records: int = 1000,
        file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        file_compression_type: str | None = None,
    ) -> DataFrame:
        """Read a CSV data source.
@@ -1045,6 +1052,7 @@ def read_csv(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)

        path = [str(p) for p in path] if isinstance(path, list) else str(path)
@@ -1064,7 +1072,7 @@ def read_csv(
    def read_parquet(
        self,
        path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        parquet_pruning: bool = True,
        file_extension: str = ".parquet",
        skip_metadata: bool = True,
@@ -1093,6 +1101,7 @@ def read_parquet(
        """
        if table_partition_cols is None:
            table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
        file_sort_order = (
            [sort_list_to_raw_sort_list(f) for f in file_sort_order]
            if file_sort_order is not None
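The read_* variants get the same normalization. A sketch of the read path (dataset layout and column names are hypothetical):

import pyarrow as pa
from datafusion import SessionContext, col, lit

ctx = SessionContext()
df = ctx.read_parquet(
    "sales/",  # e.g. sales/region=emea/..., sales/region=apac/...
    table_partition_cols=[("region", pa.string())],
)
# Partition columns take part in pruning, so a filter on them can skip
# whole directories instead of scanning every file.
df.filter(col("region") == lit("emea")).show()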
@@ -1114,7 +1123,7 @@ def read_avro(
        self,
        path: str | pathlib.Path,
        schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str]] | None = None,
+        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
        file_extension: str = ".avro",
    ) -> DataFrame:
        """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1130,6 +1139,7 @@ def read_avro(
        """
        if file_partition_cols is None:
            file_partition_cols = []
+        file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
        return DataFrame(
            self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
        )
@@ -1146,3 +1156,41 @@ def read_table(self, table: Table) -> DataFrame:
    def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
        """Execute the ``plan`` and return the results."""
        return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    @staticmethod
+    def _convert_table_partition_cols(
+        table_partition_cols: list[tuple[str, str | pa.DataType]],
+    ) -> list[tuple[str, pa.DataType]]:
+        warn = False
+        converted_table_partition_cols = []
+
+        for col, data_type in table_partition_cols:
+            if isinstance(data_type, str):
+                warn = True
+                if data_type == "string":
+                    converted_data_type = pa.string()
+                elif data_type == "int":
+                    converted_data_type = pa.int32()
+                else:
+                    message = (
+                        f"Unsupported literal data type '{data_type}' for partition "
+                        "column. Supported types are 'string' and 'int'"
+                    )
+                    raise ValueError(message)
+            else:
+                converted_data_type = data_type
+
+            converted_table_partition_cols.append((col, converted_data_type))
+
+        if warn:
+            message = (
+                "using literals for table_partition_cols data types is deprecated, "
+                "use pyarrow types instead"
+            )
+            warnings.warn(
+                message,
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return converted_table_partition_cols
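A quick sketch of what the new helper does with mixed input; calling the private static method directly here is purely for illustration:

import warnings

import pyarrow as pa
from datafusion import SessionContext

cols = [("year", "int"), ("region", "string"), ("day", pa.date32())]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    converted = SessionContext._convert_table_partition_cols(cols)

# Literal strings map to concrete Arrow types; note "int" means int32,
# not int64, and pyarrow types pass through unchanged.
assert converted == [
    ("year", pa.int32()),
    ("region", pa.string()),
    ("day", pa.date32()),
]
# Any string literal in the input triggers a single DeprecationWarning;
# an unrecognized literal such as ("day", "date") raises ValueError instead.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)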