Support async iteration of RecordBatchStream (#975) · davisp/datafusion-python@4b262be · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 4b262be

Browse files
Support async iteration of RecordBatchStream (apache#975)
* Support async iteration of RecordBatchStream
* use __anext__
* use await
* fix failing test
* Since we are raising an error instead of returning a None, we can update the type hint.

---------

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent 389164a commit 4b262be

File tree

5 files changed: 69 additions (+69) and 19 deletions (-19).

Cargo.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"]
3636
[dependencies]
3737
tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3838
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
39+
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
3940
arrow = { version = "53", features = ["pyarrow"] }
4041
datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
4142
datafusion-substrait = { version = "43.0.0", optional = true }
@@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"]
6061

6162
[profile.release]
6263
lto = true
63-
codegen-units = 1
64+
codegen-units = 1

python/datafusion/record_batch.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,24 @@ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None:
5757
"""This constructor is typically not called by the end user."""
5858
self.rbs = record_batch_stream
5959

60-
def next(self) -> RecordBatch | None:
60+
def next(self) -> RecordBatch:
6161
"""See :py:func:`__next__` for the iterator function."""
62-
try:
63-
next_batch = next(self)
64-
except StopIteration:
65-
return None
62+
return next(self)
6663

67-
return next_batch
64+
async def __anext__(self) -> RecordBatch:
65+
"""Async iterator function."""
66+
next_batch = await self.rbs.__anext__()
67+
return RecordBatch(next_batch)
6868

6969
def __next__(self) -> RecordBatch:
7070
"""Iterator function."""
7171
next_batch = next(self.rbs)
7272
return RecordBatch(next_batch)
7373

74+
def __aiter__(self) -> typing_extensions.Self:
75+
"""Async iterator function."""
76+
return self
77+
7478
def __iter__(self) -> typing_extensions.Self:
7579
"""Iterator function."""
7680
return self

python/tests/test_dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df):
761761
batch = stream.next()
762762
assert batch is not None
763763
# there should be no more batches
764-
batch = stream.next()
765-
assert batch is None
764+
with pytest.raises(StopIteration):
765+
stream.next()
766766

767767

768768
def test_repartition(df):

src/record_batch.rs

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use crate::utils::wait_for_future;
1921
use datafusion::arrow::pyarrow::ToPyArrow;
2022
use datafusion::arrow::record_batch::RecordBatch;
2123
use datafusion::physical_plan::SendableRecordBatchStream;
2224
use futures::StreamExt;
25+
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
2326
use pyo3::prelude::*;
2427
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
28+
use tokio::sync::Mutex;
2529

2630
#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
2731
pub struct PyRecordBatch {
@@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {
4347

4448
#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
4549
pub struct PyRecordBatchStream {
46-
stream: SendableRecordBatchStream,
50+
stream: Arc<Mutex<SendableRecordBatchStream>>,
4751
}
4852

4953
impl PyRecordBatchStream {
5054
pub fn new(stream: SendableRecordBatchStream) -> Self {
51-
Self { stream }
55+
Self {
56+
stream: Arc::new(Mutex::new(stream)),
57+
}
5258
}
5359
}
5460

5561
#[pymethods]
5662
impl PyRecordBatchStream {
57-
fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
58-
let result = self.stream.next();
59-
match wait_for_future(py, result) {
60-
None => Ok(None),
61-
Some(Ok(b)) => Ok(Some(b.into())),
62-
Some(Err(e)) => Err(e.into()),
63-
}
63+
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
64+
let stream = self.stream.clone();
65+
wait_for_future(py, next_stream(stream, true))
6466
}
6567

66-
fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
68+
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
6769
self.next(py)
6870
}
6971

72+
fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
73+
let stream = self.stream.clone();
74+
pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
75+
}
76+
7077
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
7178
slf
7279
}
80+
81+
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
82+
slf
83+
}
84+
}
85+
86+
async fn next_stream(
87+
stream: Arc<Mutex<SendableRecordBatchStream>>,
88+
sync: bool,
89+
) -> PyResult<PyRecordBatch> {
90+
let mut stream = stream.lock().await;
91+
match stream.next().await {
92+
Some(Ok(batch)) => Ok(batch.into()),
93+
Some(Err(e)) => Err(e.into()),
94+
None => {
95+
// Depending on whether the iteration is sync or not, we raise either a
96+
// StopIteration or a StopAsyncIteration
97+
if sync {
98+
Err(PyStopIteration::new_err("stream exhausted"))
99+
} else {
100+
Err(PyStopAsyncIteration::new_err("stream exhausted"))
101+
}
102+
}
103+
}
73104
}

0 commit comments

Comments
 (0)
0