diff --git a/src/dataframe.rs b/src/dataframe.rs
index 6fb08ba25..58533154b 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -33,6 +33,7 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
+use futures::{future, StreamExt};
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
@@ -90,8 +91,13 @@ impl PyDataFrame {
     }
 
     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
+        let df = self.df.as_ref().clone();
+
+        // Mostly the same functionality as `df.limit(0, 10).collect()`, but
+        // `df.limit(0, 10)` is a semantically different plan, which might be
+        // invalid. A case is df=`EXPLAIN ...`, as `Explain` must be the root.
+        let batches: Vec<RecordBatch> = get_batches(py, df, 10)?;
+
         let batches_as_string = pretty::pretty_format_batches(&batches);
         match batches_as_string {
             Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
@@ -102,8 +108,11 @@ impl PyDataFrame {
     fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
         let mut html_str = "<table border='1'>\n".to_string();
 
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
+        // Mostly the same functionality as `df.limit(0, 10).collect()`, but
+        // `df.limit(0, 10)` is a semantically different plan, which might be
+        // invalid. A case is df=`EXPLAIN ...`, as `Explain` must be the root.
+        let df = self.df.as_ref().clone();
+        let batches: Vec<RecordBatch> = get_batches(py, df, 10)?;
 
         if batches.is_empty() {
             html_str.push_str("</table>\n");
@@ -733,3 +742,47 @@ fn record_batch_into_schema(
 
     RecordBatch::try_new(schema, data_arrays)
 }
+
+/// Get the dataframe as a list of `RecordBatch`es containing at most `max_rows` rows.
+fn get_batches(
+    py: Python,
+    df: DataFrame,
+    max_rows: usize,
+) -> Result<Vec<RecordBatch>, PyDataFusionError> {
+    // Use `df.execute_stream_partitioned` instead of `df.execute_stream`,
+    // as the latter internally appends a `CoalescePartitionsExec` to merge
+    // the result into a single partition, and thus might cause loading of
+    // unnecessary partitions.
+    let partitioned_stream =
+        wait_for_future(py, df.execute_stream_partitioned()).map_err(py_datafusion_err)?;
+    let stream = futures::stream::iter(partitioned_stream).flatten();
+    wait_for_future(
+        py,
+        stream
+            .scan(0, |state, x| {
+                let total = *state;
+                if total >= max_rows {
+                    // Stop the stream once `max_rows` rows have been taken.
+                    future::ready(None)
+                } else {
+                    match x {
+                        Ok(batch) => {
+                            if total + batch.num_rows() <= max_rows {
+                                // Take the whole batch when it does not exceed `max_rows`.
+                                *state = total + batch.num_rows();
+                                future::ready(Some(Ok(batch)))
+                            } else {
+                                // Take only the first `max_rows - total` rows of this batch.
+                                *state = max_rows;
+                                future::ready(Some(Ok(batch.slice(0, max_rows - total))))
+                            }
+                        }
+                        Err(err) => future::ready(Some(Err(PyDataFusionError::from(err)))),
+                    }
+                }
+            })
+            .collect::<Vec<_>>(),
+    )
+    .into_iter()
+    .collect::<Result<Vec<_>, _>>()
+}