Default to ZSTD compression when writing Parquet (#981) · chenkovsky/datafusion-python@2d8b1d3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2d8b1d3

Browse files
kosiew and timsaucer
authored
Default to ZSTD compression when writing Parquet (apache#981)
* fix: update default compression to ZSTD and improve documentation for write_parquet method * fix: clarify compression level documentation for ZSTD in write_parquet method * fix: update default compression level for ZSTD to 4 in write_parquet method * fix: improve docstring formatting for DataFrame parquet writing method * feat: implement Compression enum and update write_parquet method to use it * add test * fix: remove unused import and update default compression to ZSTD in rs' write_parquet method * fix: update compression type strings to lowercase in DataFrame parquet writing method doc * test: update parquet compression tests to validate invalid and default compression levels * add comment on source of Compression * docs: enhance Compression enum documentation and add default level method * test: include gzip in default compression level tests for write_parquet * refactor: simplify Compression enum methods and improve type handling in DataFrame.write_parquet * docs: update Compression enum methods to include return type descriptions * move comment to within test * Ruff format --------- Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent db1bc62 commit 2d8b1d3

File tree

3 files changed

+101
-9
lines changed

3 files changed

+101
-9
lines changed

python/datafusion/dataframe.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@
2121

2222
from __future__ import annotations
2323
import warnings
24-
from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
24+
from typing import (
25+
Any,
26+
Iterable,
27+
List,
28+
TYPE_CHECKING,
29+
Literal,
30+
overload,
31+
Optional,
32+
Union,
33+
)
2534
from datafusion.record_batch import RecordBatchStream
2635
from typing_extensions import deprecated
2736
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -35,6 +44,60 @@
3544

3645
from datafusion._internal import DataFrame as DataFrameInternal
3746
from datafusion.expr import Expr, SortExpr, sort_or_default
47+
from enum import Enum
48+
49+
50+
# excerpt from deltalake
51+
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
52+
class Compression(Enum):
53+
"""Enum representing the available compression types for Parquet files."""
54+
55+
UNCOMPRESSED = "uncompressed"
56+
SNAPPY = "snappy"
57+
GZIP = "gzip"
58+
BROTLI = "brotli"
59+
LZ4 = "lz4"
60+
LZ0 = "lz0"
61+
ZSTD = "zstd"
62+
LZ4_RAW = "lz4_raw"
63+
64+
@classmethod
65+
def from_str(cls, value: str) -> "Compression":
66+
"""Convert a string to a Compression enum value.
67+
68+
Args:
69+
value: The string representation of the compression type.
70+
71+
Returns:
72+
The Compression enum lowercase value.
73+
74+
Raises:
75+
ValueError: If the string does not match any Compression enum value.
76+
"""
77+
try:
78+
return cls(value.lower())
79+
except ValueError:
80+
raise ValueError(
81+
f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}"
82+
)
83+
84+
def get_default_level(self) -> Optional[int]:
85+
"""Get the default compression level for the compression type.
86+
87+
Returns:
88+
The default compression level for the compression type.
89+
"""
90+
# GZIP, BROTLI default values from deltalake repo
91+
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
92+
# ZSTD default value from delta-rs
93+
# https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
94+
if self == Compression.GZIP:
95+
return 6
96+
elif self == Compression.BROTLI:
97+
return 1
98+
elif self == Compression.ZSTD:
99+
return 4
100+
return None
38101

39102

40103
class DataFrame:
@@ -620,17 +683,36 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
620683
def write_parquet(
621684
self,
622685
path: str | pathlib.Path,
623-
compression: str = "uncompressed",
686+
compression: Union[str, Compression] = Compression.ZSTD,
624687
compression_level: int | None = None,
625688
) -> None:
626689
"""Execute the :py:class:`DataFrame` and write the results to a Parquet file.
627690
628691
Args:
629692
path: Path of the Parquet file to write.
630-
compression: Compression type to use.
631-
compression_level: Compression level to use.
632-
"""
633-
self.df.write_parquet(str(path), compression, compression_level)
693+
compression: Compression type to use. Default is "ZSTD".
694+
Available compression types are:
695+
- "uncompressed": No compression.
696+
- "snappy": Snappy compression.
697+
- "gzip": Gzip compression.
698+
- "brotli": Brotli compression.
699+
- "lz0": LZ0 compression.
700+
- "lz4": LZ4 compression.
701+
- "lz4_raw": LZ4_RAW compression.
702+
- "zstd": Zstandard compression.
703+
compression_level: Compression level to use. For ZSTD, the
704+
recommended range is 1 to 22, with the default being 4. Higher levels
705+
provide better compression but slower speed.
706+
"""
707+
# Convert string to Compression enum if necessary
708+
if isinstance(compression, str):
709+
compression = Compression.from_str(compression)
710+
711+
if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
712+
if compression_level is None:
713+
compression_level = compression.get_default_level()
714+
715+
self.df.write_parquet(str(path), compression.value, compression_level)
634716

635717
def write_json(self, path: str | pathlib.Path) -> None:
636718
"""Execute the :py:class:`DataFrame` and write the results to a JSON file.

python/tests/test_dataframe.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,14 +1107,24 @@ def test_write_compressed_parquet_wrong_compression_level(
11071107
)
11081108

11091109

1110-
@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
1111-
def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compression):
1110+
@pytest.mark.parametrize("compression", ["wrong"])
1111+
def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
11121112
path = tmp_path
11131113

11141114
with pytest.raises(ValueError):
11151115
df.write_parquet(str(path), compression=compression)
11161116

11171117

1118+
@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
1119+
def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
1120+
# Test write_parquet with zstd, brotli, gzip default compression level,
1121+
# ie don't specify compression level
1122+
# should complete without error
1123+
path = tmp_path
1124+
1125+
df.write_parquet(str(path), compression=compression)
1126+
1127+
11181128
def test_dataframe_export(df) -> None:
11191129
# Guarantees that we have the canonical implementation
11201130
# reading our dataframe export

src/dataframe.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ impl PyDataFrame {
463463
/// Write a `DataFrame` to a Parquet file.
464464
#[pyo3(signature = (
465465
path,
466-
compression="uncompressed",
466+
compression="zstd",
467467
compression_level=None
468468
))]
469469
fn write_parquet(

0 commit comments

Comments (0)