feat: add execute_stream and execute_stream_partitioned (#610) · samuelcolvin/datafusion-python@5e97701 · GitHub

Commit 5e97701

feat: add execute_stream and execute_stream_partitioned (apache#610)
1 parent 6a895c6 commit 5e97701

File tree

datafusion/tests/test_dataframe.py
src/dataframe.rs
src/record_batch.rs

3 files changed: +79 -5 lines changed

datafusion/tests/test_dataframe.py

Lines changed: 30 additions & 0 deletions
@@ -623,6 +623,36 @@ def test_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_execute_stream(df):
+    stream = df.execute_stream()
+    assert all(batch is not None for batch in stream)
+    assert not list(stream)  # after one iteration the generator must be exhausted
+
+
+@pytest.mark.parametrize("schema", [True, False])
+def test_execute_stream_to_arrow_table(df, schema):
+    stream = df.execute_stream()
+
+    if schema:
+        pyarrow_table = pa.Table.from_batches(
+            (batch.to_pyarrow() for batch in stream), schema=df.schema()
+        )
+    else:
+        pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))
+
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (3, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+
+
+def test_execute_stream_partitioned(df):
+    streams = df.execute_stream_partitioned()
+    assert all(batch is not None for stream in streams for batch in stream)
+    assert all(
+        not list(stream) for stream in streams
+    )  # after one iteration all generators must be exhausted
+
+
 def test_empty_to_arrow_table(df):
     # Convert empty datafusion dataframe to pyarrow Table
     pyarrow_table = df.limit(0).to_arrow_table()
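
The new tests also document the intended usage pattern: execute_stream() returns a single-pass iterator of record batches, so results can be converted incrementally instead of being collected up front. A minimal standalone sketch of the same pattern (the ctx setup and column values here are illustrative assumptions; the tests obtain an equivalent DataFrame from a df fixture):

import pyarrow as pa
from datafusion import SessionContext

# Build a small three-column DataFrame, mirroring the shape the tests assert on.
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, 8, 9])],
    names=["a", "b", "c"],
)
df = ctx.create_dataframe([[batch]])

# The stream is single-pass: iterate it once, converting each batch as it arrives.
stream = df.execute_stream()
table = pa.Table.from_batches((b.to_pyarrow() for b in stream), schema=df.schema())
assert table.shape == (3, 3)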

src/dataframe.rs

Lines changed: 40 additions & 5 deletions
@@ -15,21 +15,27 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::physical_plan::PyExecutionPlan;
-use crate::sql::logical::PyLogicalPlan;
-use crate::utils::wait_for_future;
-use crate::{errors::DataFusionError, expr::PyExpr};
+use std::sync::Arc;
+
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
+use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::prelude::*;
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
-use std::sync::Arc;
+use tokio::task::JoinHandle;
+
+use crate::errors::py_datafusion_err;
+use crate::physical_plan::PyExecutionPlan;
+use crate::record_batch::PyRecordBatchStream;
+use crate::sql::logical::PyLogicalPlan;
+use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::{errors::DataFusionError, expr::PyExpr};
 
 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -399,6 +405,35 @@ impl PyDataFrame {
         })
     }
 
+    fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
+        // create a Tokio runtime to run the async code
+        let rt = &get_tokio_runtime(py).0;
+        let df = self.df.as_ref().clone();
+        let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
+            rt.spawn(async move { df.execute_stream().await });
+        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+        Ok(PyRecordBatchStream::new(stream?))
+    }
+
+    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
+        // create a Tokio runtime to run the async code
+        let rt = &get_tokio_runtime(py).0;
+        let df = self.df.as_ref().clone();
+        let fut: JoinHandle<datafusion_common::Result<Vec<SendableRecordBatchStream>>> =
+            rt.spawn(async move { df.execute_stream_partitioned().await });
+        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+
+        match stream {
+            Ok(batches) => Ok(batches
+                .into_iter()
+                .map(|batch_stream| PyRecordBatchStream::new(batch_stream))
+                .collect()),
+            _ => Err(PyValueError::new_err(
+                "Unable to execute stream partitioned",
+            )),
+        }
+    }
+
     /// Convert to pandas dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
     fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
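
Implementation note: both methods clone the underlying DataFrame, spawn the async DataFusion call on the shared Tokio runtime, and block the calling Python thread via wait_for_future only until the stream handle is ready; batches themselves are still produced lazily as the stream is consumed. execute_stream_partitioned() surfaces one independent stream per output partition. A sketch of consuming the partitioned form from Python (reusing the ctx/df setup from the earlier sketch):

# Each partition yields its own single-pass stream, so rows can be counted
# while holding at most one batch per partition in memory.
total_rows = 0
for partition in df.execute_stream_partitioned():
    for batch in partition:
        total_rows += batch.to_pyarrow().num_rows
assert total_rows == 3  # the example DataFrame holds a single 3-row batch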

src/record_batch.rs

Lines changed: 9 additions & 0 deletions
@@ -20,6 +20,7 @@ use datafusion::arrow::pyarrow::ToPyArrow;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::physical_plan::SendableRecordBatchStream;
 use futures::StreamExt;
+use pyo3::prelude::*;
 use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
 
 #[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
@@ -61,4 +62,12 @@ impl PyRecordBatchStream {
             Some(Err(e)) => Err(e.into()),
         }
     }
+
+    fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
+        self.next(py)
+    }
+
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
 }
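
The __iter__ and __next__ additions are what make PyRecordBatchStream a native Python iterator: pyo3 maps a returned Ok(None) from __next__ to StopIteration, which is why the tests can drive the stream with plain for loops and why a second pass over an exhausted stream yields nothing. Any iterator-consuming API therefore works directly on the stream, for example (a sketch, continuing from the df defined in the first example):

import itertools

stream = df.execute_stream()
first_two = list(itertools.islice(stream, 2))  # take up to two batches
remaining = list(stream)                       # drains whatever is left
assert len(first_two) + len(remaining) == 1    # the example plan emits one batch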
