From fdb713a83af3e969bddebaf0536315475757c687 Mon Sep 17 00:00:00 2001
From: Maren van Otterdijk
Date: Wed, 29 Jan 2025 01:51:23 +0100
Subject: [PATCH] expose DataFrameWriteOptions for csv, json and parquet

---
 python/datafusion/dataframe.py | 60 +++++++++++++++++++++++----
 python/datafusion/options.py   | 47 +++++++++++++++++++++
 src/dataframe.rs               | 56 ++++++++++++++++++-------
 src/lib.rs                     |  1 +
 src/options.rs                 | 74 ++++++++++++++++++++++++++++++++++
 5 files changed, 216 insertions(+), 22 deletions(-)
 create mode 100644 python/datafusion/options.py
 create mode 100644 src/options.rs

diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 7413a5fa3..b0e8ccd00 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -27,6 +27,7 @@
     Any,
     Iterable,
     List,
+    Dict,
     Literal,
     Optional,
     Union,
@@ -50,6 +51,7 @@
 
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion.expr import Expr, SortExpr, sort_or_default
+from datafusion.options import write_options_to_raw_write_options
 
 
 # excerpt from deltalake
@@ -678,20 +680,37 @@ def except_all(self, other: DataFrame) -> DataFrame:
         """
         return DataFrame(self.df.except_all(other.df))
 
-    def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None:
+    def write_csv(
+        self,
+        path: str | pathlib.Path,
+        with_header: bool = False,
+        write_options: Optional[Dict] = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a CSV file.
 
         Args:
             path: Path of the CSV file to write.
             with_header: If true, output the CSV header row.
-        """
-        self.df.write_csv(str(path), with_header)
+            write_options: Write options to use, given as a dictionary.
+                Available options are:
+                - "insert_operation": one of the following:
+                    - "append": Appends new rows to the existing table without modifying any existing rows. This corresponds to the SQL INSERT INTO query.
+                    - "overwrite": Overwrites all existing rows in the table with the new rows. This corresponds to the SQL INSERT OVERWRITE query.
+                    - "replace": If any existing rows collide with the inserted rows (typically based on a unique key or primary key), those existing rows are replaced. This corresponds to the SQL REPLACE INTO query and its equivalents.
+                - "single_file_output": bool indicating whether the output should be written to a single file.
+                - "partition_by": a list of column names (as strings) to set up Hive partitioning.
+                - "sort_by": a list of sort expressions to sort the output by.
+        """
+        self.df.write_csv(
+            str(path), with_header, write_options_to_raw_write_options(write_options)
+        )
 
     def write_parquet(
         self,
         path: str | pathlib.Path,
         compression: Union[str, Compression] = Compression.ZSTD,
         compression_level: int | None = None,
+        write_options: Optional[Dict] = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
 
@@ -710,6 +729,15 @@ def write_parquet(
             compression_level: Compression level to use. For ZSTD, the
                 recommended range is 1 to 22, with the default being 4. Higher
                 levels provide better compression but slower speed.
+            write_options: Write options to use, given as a dictionary.
+                Available options are:
+                - "insert_operation": one of the following:
+                    - "append": Appends new rows to the existing table without modifying any existing rows. This corresponds to the SQL INSERT INTO query.
+                    - "overwrite": Overwrites all existing rows in the table with the new rows. This corresponds to the SQL INSERT OVERWRITE query.
+                    - "replace": If any existing rows collide with the inserted rows (typically based on a unique key or primary key), those existing rows are replaced. This corresponds to the SQL REPLACE INTO query and its equivalents.
+                - "single_file_output": bool indicating whether the output should be written to a single file.
+                - "partition_by": a list of column names (as strings) to set up Hive partitioning.
+                - "sort_by": a list of sort expressions to sort the output by.
         """
         # Convert string to Compression enum if necessary
         if isinstance(compression, str):
@@ -719,15 +747,33 @@ def write_parquet(
         if compression_level is None:
             compression_level = compression.get_default_level()
 
-        self.df.write_parquet(str(path), compression.value, compression_level)
+        self.df.write_parquet(
+            str(path),
+            compression.value,
+            compression_level,
+            write_options_to_raw_write_options(write_options),
+        )
 
-    def write_json(self, path: str | pathlib.Path) -> None:
+    def write_json(
+        self,
+        path: str | pathlib.Path,
+        write_options: Optional[Dict] = None,
+    ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
 
         Args:
             path: Path of the JSON file to write.
-        """
-        self.df.write_json(str(path))
+            write_options: Write options to use, given as a dictionary.
+                Available options are:
+                - "insert_operation": one of the following:
+                    - "append": Appends new rows to the existing table without modifying any existing rows. This corresponds to the SQL INSERT INTO query.
+                    - "overwrite": Overwrites all existing rows in the table with the new rows. This corresponds to the SQL INSERT OVERWRITE query.
+                    - "replace": If any existing rows collide with the inserted rows (typically based on a unique key or primary key), those existing rows are replaced. This corresponds to the SQL REPLACE INTO query and its equivalents.
+                - "single_file_output": bool indicating whether the output should be written to a single file.
+                - "partition_by": a list of column names (as strings) to set up Hive partitioning.
+                - "sort_by": a list of sort expressions to sort the output by.
+        """
+        self.df.write_json(str(path), write_options_to_raw_write_options(write_options))
 
     def to_arrow_table(self) -> pa.Table:
         """Execute the :py:class:`DataFrame` and convert it into an Arrow Table.
diff --git a/python/datafusion/options.py b/python/datafusion/options.py
new file mode 100644
index 000000000..a6d44f836
--- /dev/null
+++ b/python/datafusion/options.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Option conversion functions."""
+
+from typing import Optional, Dict
+from datafusion.expr import sort_list_to_raw_sort_list
+
+
+def write_options_to_raw_write_options(write_options: Optional[Dict]) -> Dict:
+    """Convert a dictionary of write options into the format expected by the pyo3 bindings.
+
+    Validates that no superfluous keys are specified, then:
+    - adds default keys for each expected write option.
+    - converts "sort_by" into a raw sort list.
+    """
+    defaults = {
+        "insert_operation": None,
+        "single_file_output": None,
+        "partition_by": None,
+        "sort_by": None,
+    }
+
+    if write_options is not None:
+        invalid_write_options = set(write_options) - set(defaults)
+        if invalid_write_options:
+            raise ValueError(f"Invalid write options: {invalid_write_options}")
+
+        results = {**defaults, **write_options}
+        if "sort_by" in write_options:
+            results["sort_by"] = sort_list_to_raw_sort_list(write_options["sort_by"])
+        return results
+    else:
+        return defaults
diff --git a/src/dataframe.rs b/src/dataframe.rs
index b875480a7..0bd0ade26 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -29,7 +29,7 @@ use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::common::UnnestOptions;
 use datafusion::config::{CsvOptions, TableParquetOptions};
-use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
+use datafusion::dataframe::DataFrame;
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
@@ -41,6 +41,7 @@ use tokio::task::JoinHandle;
 
 use crate::errors::py_datafusion_err;
 use crate::expr::sort_expr::to_sort_expressions;
+use crate::options::{make_dataframe_write_options, PyDataFrameWriteOptions};
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
 use crate::sql::logical::PyLogicalPlan;
@@ -444,18 +445,29 @@ impl PyDataFrame {
     }
 
     /// Write a `DataFrame` to a CSV file.
-    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyResult<()> {
+    #[pyo3(signature = (
+        path,
+        with_header,
+        write_options=None,
+    ))]
+    fn write_csv(
+        &self,
+        path: &str,
+        with_header: bool,
+        write_options: Option<PyDataFrameWriteOptions>,
+        py: Python,
+    ) -> PyResult<()> {
+        let write_options = make_dataframe_write_options(write_options)?;
         let csv_options = CsvOptions {
             has_header: Some(with_header),
             ..Default::default()
         };
         wait_for_future(
             py,
-            self.df.as_ref().clone().write_csv(
-                path,
-                DataFrameWriteOptions::new(),
-                Some(csv_options),
-            ),
+            self.df
+                .as_ref()
+                .clone()
+                .write_csv(path, write_options, Some(csv_options)),
         )?;
         Ok(())
     }
@@ -464,13 +476,15 @@ impl PyDataFrame {
     #[pyo3(signature = (
         path,
         compression="zstd",
-        compression_level=None
+        compression_level=None,
+        write_options=None,
     ))]
     fn write_parquet(
         &self,
         path: &str,
         compression: &str,
         compression_level: Option<u32>,
+        write_options: Option<PyDataFrameWriteOptions>,
         py: Python,
     ) -> PyResult<()> {
         fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
@@ -510,25 +524,37 @@ impl PyDataFrame {
         let mut options = TableParquetOptions::default();
         options.global.compression = Some(compression_string);
 
+        let write_options = make_dataframe_write_options(write_options)?;
+
         wait_for_future(
             py,
-            self.df.as_ref().clone().write_parquet(
-                path,
-                DataFrameWriteOptions::new(),
-                Option::from(options),
-            ),
+            self.df
+                .as_ref()
+                .clone()
+                .write_parquet(path, write_options, Option::from(options)),
         )?;
         Ok(())
     }
 
     /// Executes a query and writes the results to a partitioned JSON file.
-    fn write_json(&self, path: &str, py: Python) -> PyResult<()> {
+    #[pyo3(signature = (
+        path,
+        write_options=None,
+    ))]
+    fn write_json(
+        &self,
+        path: &str,
+        write_options: Option<PyDataFrameWriteOptions>,
+        py: Python,
+    ) -> PyResult<()> {
+        let write_options = make_dataframe_write_options(write_options)?;
+
         wait_for_future(
             py,
             self.df
                 .as_ref()
                 .clone()
-                .write_json(path, DataFrameWriteOptions::new(), None),
+                .write_json(path, write_options, None),
         )?;
         Ok(())
     }
diff --git a/src/lib.rs b/src/lib.rs
index 1111d5d06..60a075cda 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -52,6 +52,7 @@ mod record_batch;
 pub mod sql;
 pub mod store;
 
+mod options;
 #[cfg(feature = "substrait")]
 pub mod substrait;
 #[allow(clippy::borrow_deref_ref)]
diff --git a/src/options.rs b/src/options.rs
new file mode 100644
index 000000000..8b5a0979b
--- /dev/null
+++ b/src/options.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{dataframe::DataFrameWriteOptions, logical_expr::dml::InsertOp};
+use pyo3::{exceptions::PyValueError, FromPyObject, PyErr, PyResult};
+
+use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
+
+#[derive(FromPyObject)]
+#[pyo3(from_item_all)]
+pub struct PyDataFrameWriteOptions {
+    insert_operation: Option<String>,
+    single_file_output: Option<bool>,
+    partition_by: Option<Vec<String>>,
+    sort_by: Option<Vec<PySortExpr>>,
+}
+
+impl TryInto<DataFrameWriteOptions> for PyDataFrameWriteOptions {
+    type Error = PyErr;
+
+    fn try_into(self) -> PyResult<DataFrameWriteOptions> {
+        let mut options = DataFrameWriteOptions::new();
+        if let Some(insert_op) = self.insert_operation {
+            let op = match insert_op.as_str() {
+                "append" => InsertOp::Append,
+                "overwrite" => InsertOp::Overwrite,
+                "replace" => InsertOp::Replace,
+                _ => {
+                    return Err(PyValueError::new_err(format!(
+                        "Unrecognized insert op {insert_op}"
+                    )))
+                }
+            };
+            options = options.with_insert_operation(op);
+        }
+        if let Some(single_file_output) = self.single_file_output {
+            options = options.with_single_file_output(single_file_output);
+        }
+
+        if let Some(partition_by) = self.partition_by {
+            options = options.with_partition_by(partition_by);
+        }
+
+        if let Some(sort_by) = self.sort_by {
+            options = options.with_sort_by(to_sort_expressions(sort_by));
+        }
+
+        Ok(options)
+    }
+}
+
+pub fn make_dataframe_write_options(
+    write_options: Option<PyDataFrameWriteOptions>,
+) -> PyResult<DataFrameWriteOptions> {
+    if let Some(wo) = write_options {
+        wo.try_into()
+    } else {
+        Ok(DataFrameWriteOptions::new())
+    }
+}
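
Usage sketch of the new write_options argument from the Python side (illustrative only, not part of the patch; it assumes a local SessionContext, an "out/" output directory, and the option keys documented in the docstrings above):

    from datafusion import SessionContext, col

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    # Overwrite any existing data, Hive-partition the output by column "b",
    # and sort rows within each written file by column "a".
    df.write_parquet(
        "out/",
        write_options={
            "insert_operation": "overwrite",
            "partition_by": ["b"],
            "sort_by": [col("a").sort(ascending=True)],
        },
    )

Unknown keys in the dictionary raise a ValueError from write_options_to_raw_write_options before anything is written.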