
Commit 5a5eb39 (1 parent: 98dc06b)

feat: Add Parquet writer option autodetection

File tree: 3 files changed, +30 -3 lines

python/datafusion/__init__.py (5 additions, 1 deletion)

@@ -46,7 +46,11 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions
+from .dataframe import (
+    DataFrame,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .expr import (
     Expr,
     WindowFrame,

python/datafusion/dataframe.py (9 additions, 2 deletions)

@@ -55,6 +55,7 @@
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion._internal import expr as expr_internal
 
+from dataclasses import dataclass
 from enum import Enum
 
 
@@ -873,7 +874,7 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: Union[str, Compression] = Compression.ZSTD,
+        compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -894,7 +895,13 @@ def write_parquet(
                 recommended range is 1 to 22, with the default being 4. Higher levels
                 provide better compression but slower speed.
         """
-        # Convert string to Compression enum if necessary
+        if isinstance(compression, ParquetWriterOptions):
+            if compression_level is not None:
+                msg = "compression_level should be None when using ParquetWriterOptions"
+                raise ValueError(msg)
+            self.write_parquet_with_options(path, compression)
+            return
+
         if isinstance(compression, str):
             compression = Compression.from_str(compression)
 
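
In short, write_parquet now detects when a ParquetWriterOptions object is passed as the compression argument and delegates to write_parquet_with_options, rejecting an explicit compression_level in that case. A minimal usage sketch of the new path (the sample data and the output path "out_parquet" are illustrative, not part of this commit):

    from datafusion import SessionContext, ParquetWriterOptions

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    # Passing writer options instead of a Compression value routes the call
    # through write_parquet_with_options.
    options = ParquetWriterOptions(compression="gzip", compression_level=6)
    df.write_parquet("out_parquet", options)

    # Combining options with an explicit compression_level raises ValueError
    # per the guard added in this commit:
    # df.write_parquet("out_parquet", options, compression_level=1)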

python/tests/test_dataframe.py (16 additions, 0 deletions)

@@ -2038,6 +2038,22 @@ def test_write_parquet_with_options_column_options(df, tmp_path):
     assert col["encodings"] == result["encodings"]
 
 
+def test_write_parquet_options(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    df.write_parquet(str(tmp_path), options)
+
+    result = pq.read_table(str(tmp_path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+def test_write_parquet_options_error(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    with pytest.raises(ValueError):
+        df.write_parquet(str(tmp_path), options, compression_level=1)
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
