feat: add basic compression configuration to write_parquet (#459) · llama90/arrow-datafusion-python@499f045 · GitHub

Commit 499f045

feat: add basic compression configuration to write_parquet (apache#459)
1 parent 217ede8 commit 499f045

File tree

2 files changed: +122 −3 lines


datafusion/tests/test_dataframe.py

Lines changed: 67 additions & 0 deletions

@@ -14,8 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 
 import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 
 from datafusion import functions as f
@@ -645,3 +647,68 @@ def test_describe(df):
         "b": [3.0, 3.0, 5.0, 1.0, 4.0, 6.0, 5.0],
         "c": [3.0, 3.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0],
     }
+
+
+def test_write_parquet(df, tmp_path):
+    path = tmp_path
+
+    df.write_parquet(str(path))
+    result = pq.read_table(str(path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "compression, compression_level",
+    [("gzip", 6), ("brotli", 7), ("zstd", 15)],
+)
+def test_write_compressed_parquet(
+    df, tmp_path, compression, compression_level
+):
+    path = tmp_path
+
+    df.write_parquet(
+        str(path), compression=compression, compression_level=compression_level
+    )
+
+    # test that the actual compression scheme is the one written
+    for root, dirs, files in os.walk(path):
+        for file in files:
+            if file.endswith(".parquet"):
+                metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict()
+                for row_group in metadata["row_groups"]:
+                    for columns in row_group["columns"]:
+                        assert columns["compression"].lower() == compression
+
+    result = pq.read_table(str(path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "compression, compression_level",
+    [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
+)
+def test_write_compressed_parquet_wrong_compression_level(
+    df, tmp_path, compression, compression_level
+):
+    path = tmp_path
+
+    with pytest.raises(ValueError):
+        df.write_parquet(
+            str(path),
+            compression=compression,
+            compression_level=compression_level,
+        )
+
+
+@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
+def test_write_compressed_parquet_missing_compression_level(
+    df, tmp_path, compression
+):
+    path = tmp_path
+
+    with pytest.raises(ValueError):
+        df.write_parquet(str(path), compression=compression)

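For reference, the new keyword arguments can be exercised directly from Python. A minimal sketch, assuming a SessionContext and a small in-memory RecordBatch; the table data and output paths are illustrative, not part of the commit:

import pyarrow as pa
import pyarrow.parquet as pq
from datafusion import SessionContext

# Build a small DataFrame to write out (illustrative data, mirroring the
# `df` fixture used by the tests above).
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4.0, 5.0, 6.0])],
    names=["a", "b"],
)
df = ctx.create_dataframe([[batch]])

# gzip falls back to its default level 6 when compression_level is omitted.
df.write_parquet("/tmp/gzip_table", compression="gzip")

# zstd and brotli require an explicit level (15 is within zstd's range,
# per the parametrized test above).
df.write_parquet("/tmp/zstd_table", compression="zstd", compression_level=15)

# Read back with pyarrow and confirm the round trip.
assert pq.read_table("/tmp/zstd_table").to_pydict() == df.to_pydict()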
src/dataframe.rs

Lines changed: 55 additions & 3 deletions

@@ -23,8 +23,10 @@ use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
+use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::prelude::*;
-use pyo3::exceptions::PyTypeError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
 use std::sync::Arc;
@@ -308,8 +310,58 @@ impl PyDataFrame {
     }
 
     /// Write a `DataFrame` to a Parquet file.
-    fn write_parquet(&self, path: &str, py: Python) -> PyResult<()> {
-        wait_for_future(py, self.df.as_ref().clone().write_parquet(path, None))?;
+    #[pyo3(signature = (
+        path,
+        compression="uncompressed",
+        compression_level=None
+    ))]
+    fn write_parquet(
+        &self,
+        path: &str,
+        compression: &str,
+        compression_level: Option<u32>,
+        py: Python,
+    ) -> PyResult<()> {
+        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
+            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
+        }
+
+        let compression_type = match compression.to_lowercase().as_str() {
+            "snappy" => Compression::SNAPPY,
+            "gzip" => Compression::GZIP(
+                GzipLevel::try_new(compression_level.unwrap_or(6))
+                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
+            ),
+            "brotli" => Compression::BROTLI(
+                BrotliLevel::try_new(verify_compression_level(compression_level)?)
+                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
+            ),
+            "zstd" => Compression::ZSTD(
+                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
+                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
+            ),
+            "lzo" => Compression::LZO,
+            "lz4" => Compression::LZ4,
+            "lz4_raw" => Compression::LZ4_RAW,
+            "uncompressed" => Compression::UNCOMPRESSED,
+            _ => {
+                return Err(PyValueError::new_err(format!(
+                    "Unrecognized compression type {compression}"
+                )));
+            }
+        };
+
+        let writer_properties = WriterProperties::builder()
+            .set_compression(compression_type)
+            .build();
+
+        wait_for_future(
+            py,
+            self.df
+                .as_ref()
+                .clone()
+                .write_parquet(path, Option::from(writer_properties)),
+        )?;
         Ok(())
     }

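Note the asymmetry in the match above: gzip falls back to GzipLevel 6 when no compression_level is passed, while brotli and zstd go through verify_compression_level and fail without one; out-of-range levels and unknown codec names are rejected as well. Continuing with the df from the earlier sketch, the resulting Python-facing behavior looks roughly like this (paths illustrative):

import pytest

# `df` is the DataFrame constructed in the previous sketch.
# snappy takes no level; lzo, lz4, lz4_raw and uncompressed likewise
# map to level-free variants in the match above.
df.write_parquet("/tmp/snappy_table", compression="snappy")

# brotli/zstd without a level surface the "compression_level is not
# defined" ValueError raised on the Rust side.
with pytest.raises(ValueError):
    df.write_parquet("/tmp/brotli_table", compression="brotli")

# Out-of-range levels fail validation in the parquet crate
# (the tests above use zstd level 23 as an invalid value).
with pytest.raises(ValueError):
    df.write_parquet("/tmp/zstd_table", compression="zstd", compression_level=23)

# Unknown codec names are rejected outright.
with pytest.raises(ValueError):
    df.write_parquet("/tmp/oops_table", compression="wrong")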