|
21 | 21 |
|
22 | 22 | from __future__ import annotations
|
23 | 23 | import warnings
|
24 |
| -from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload |
| 24 | +from typing import ( |
| 25 | + Any, |
| 26 | + Iterable, |
| 27 | + List, |
| 28 | + TYPE_CHECKING, |
| 29 | + Literal, |
| 30 | + overload, |
| 31 | + Optional, |
| 32 | + Union, |
| 33 | +) |
25 | 34 | from datafusion.record_batch import RecordBatchStream
|
26 | 35 | from typing_extensions import deprecated
|
27 | 36 | from datafusion.plan import LogicalPlan, ExecutionPlan
|
|
35 | 44 |
|
36 | 45 | from datafusion._internal import DataFrame as DataFrameInternal
|
37 | 46 | from datafusion.expr import Expr, SortExpr, sort_or_default
|
| 47 | +from enum import Enum |
| 48 | + |
| 49 | + |
# excerpt from deltalake
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
class Compression(Enum):
    """Enum representing the available compression types for Parquet files."""

    UNCOMPRESSED = "uncompressed"
    SNAPPY = "snappy"
    GZIP = "gzip"
    BROTLI = "brotli"
    LZ4 = "lz4"
    # NOTE(review): Parquet's codec name is conventionally "lzo"; "lz0" looks
    # like a typo carried over from the referenced PR discussion. Not renamed
    # here because changing the member or its value would break callers --
    # confirm against the Rust side before fixing.
    LZ0 = "lz0"
    ZSTD = "zstd"
    LZ4_RAW = "lz4_raw"

    @classmethod
    def from_str(cls, value: str) -> "Compression":
        """Convert a string to a Compression enum value.

        Matching is case-insensitive: ``value`` is lowercased before lookup.

        Args:
            value: The string representation of the compression type.

        Returns:
            The Compression enum lowercase value.

        Raises:
            ValueError: If the string does not match any Compression enum value.
        """
        try:
            return cls(value.lower())
        except ValueError as err:
            valid = [item.value for item in Compression]
            # Chain the original error so the lookup failure context is kept.
            raise ValueError(
                f"{value} is not a valid Compression. Valid values are: {valid}"
            ) from err

    def get_default_level(self) -> Optional[int]:
        """Get the default compression level for the compression type.

        Returns:
            The default level for codecs that accept one (GZIP, BROTLI, ZSTD);
            ``None`` for all other codecs.
        """
        # GZIP, BROTLI default values from deltalake repo
        # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
        # ZSTD default value from delta-rs
        # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
        defaults = {
            Compression.GZIP: 6,
            Compression.BROTLI: 1,
            Compression.ZSTD: 4,
        }
        return defaults.get(self)
38 | 101 |
|
39 | 102 |
|
40 | 103 | class DataFrame:
|
@@ -620,17 +683,36 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
|
620 | 683 | def write_parquet(
|
621 | 684 | self,
|
622 | 685 | path: str | pathlib.Path,
|
623 |
| - compression: str = "uncompressed", |
| 686 | + compression: Union[str, Compression] = Compression.ZSTD, |
624 | 687 | compression_level: int | None = None,
|
625 | 688 | ) -> None:
|
626 | 689 | """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
|
627 | 690 |
|
628 | 691 | Args:
|
629 | 692 | path: Path of the Parquet file to write.
|
630 |
| - compression: Compression type to use. |
631 |
| - compression_level: Compression level to use. |
632 |
| - """ |
633 |
| - self.df.write_parquet(str(path), compression, compression_level) |
| 693 | + compression: Compression type to use. Default is "ZSTD". |
| 694 | + Available compression types are: |
| 695 | + - "uncompressed": No compression. |
| 696 | + - "snappy": Snappy compression. |
| 697 | + - "gzip": Gzip compression. |
| 698 | + - "brotli": Brotli compression. |
| 699 | + - "lz0": LZ0 compression. |
| 700 | + - "lz4": LZ4 compression. |
| 701 | + - "lz4_raw": LZ4_RAW compression. |
| 702 | + - "zstd": Zstandard compression. |
| 703 | + compression_level: Compression level to use. For ZSTD, the |
| 704 | + recommended range is 1 to 22, with the default being 4. Higher levels |
| 705 | + provide better compression but slower speed. |
| 706 | + """ |
| 707 | + # Convert string to Compression enum if necessary |
| 708 | + if isinstance(compression, str): |
| 709 | + compression = Compression.from_str(compression) |
| 710 | + |
| 711 | + if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}: |
| 712 | + if compression_level is None: |
| 713 | + compression_level = compression.get_default_level() |
| 714 | + |
| 715 | + self.df.write_parquet(str(path), compression.value, compression_level) |
634 | 716 |
|
635 | 717 | def write_json(self, path: str | pathlib.Path) -> None:
|
636 | 718 | """Execute the :py:class:`DataFrame` and write the results to a JSON file.
|
|
0 commit comments