8000 Pyarrow filter pushdowns (#735) · PhVHoang/datafusion-python@45a6844 · GitHub
[go: up one dir, main page]

Skip to content

Commit 45a6844

Browse files
Pyarrow filter pushdowns (apache#735)
* fix pushdown for pyarrow filter IsNull: the conversion was incorrectly passing in the expression itself as the `nan_as_null` argument. This caused the pushdown to silently fail. * expand the Expr::Literal's that can be used in PyArrowFilterExpression. Closes apache#703
1 parent 532dc38 commit 45a6844

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

python/datafusion/tests/test_context.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
import gzip
1818
import os
19+
import datetime as dt
1920

2021
import pyarrow as pa
2122
import pyarrow.dataset as ds
@@ -303,6 +304,59 @@ def test_dataset_filter(ctx, capfd):
303304
assert result[0].column(1) == pa.array([-3])
304305

305306

307+
def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
308+
"""Ensure that pyarrow filter gets pushed down for `IsNull`"""
309+
# create a RecordBatch and register it as a pyarrow.dataset.Dataset
310+
batch = pa.RecordBatch.from_arrays(
311+
[pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, None, 9])],
312+
names=["a", "b", "c"],
313+
)
314+
dataset = ds.dataset([batch])
315+
ctx.register_dataset("t", dataset)
316+
# Make sure the filter was pushed down in Physical Plan
317+
df = ctx.sql("SELECT a FROM t WHERE c is NULL")
318+
df.explain()
319+
captured = capfd.readouterr()
320+
assert "filter_expr=is_null(c, {nan_is_null=false})" in captured.out
321+
322+
result = df.collect()
323+
assert result[0].column(0) == pa.array([2])
324+
325+
326+
def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
327+
"""Ensure that pyarrow filter gets pushed down for timestamp"""
328+
# Ref: https://github.com/apache/datafusion-python/issues/703
329+
330+
# create pyarrow dataset with no actual files
331+
col_type = pa.timestamp("ns", "+00:00")
332+
nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type)
333+
pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
334+
pa_dataset_format = pa.dataset.ParquetFileFormat()
335+
pa_dataset_partition = pa.dataset.field("a") <= nyd_2000
336+
fragments = [
337+
# NOTE: we never actually make this file.
338+
# Working predicate pushdown means it never gets accessed
339+
pa_dataset_format.make_fragment(
340+
"1.parquet",
341+
filesystem=pa_dataset_fs,
342+
partition_expression=pa_dataset_partition,
343+
)
344+
]
345+
pa_dataset = pa.dataset.FileSystemDataset(
346+
fragments,
347+
pa.schema([pa.field("a", col_type)]),
348+
pa_dataset_format,
349+
pa_dataset_fs,
350+
)
351+
352+
ctx.register_dataset("t", pa_dataset)
353+
354+
# the partition for our only fragment is for a <= 2000-01-01.
355+
# so querying for a > 2024-01-01 should not touch any files
356+
df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
357+
assert df.collect() == []
358+
359+
306360
def test_dataset_filter_nested_data(ctx):
307361
# create Arrow StructArrays to test nested data types
308362
data = pa.StructArray.from_arrays(

src/pyarrow_filter_expression.rs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use pyo3::prelude::*;
2121
use std::convert::TryFrom;
2222
use std::result::Result;
2323

24+
use arrow::pyarrow::ToPyArrow;
2425
use datafusion_common::{Column, ScalarValue};
2526
use datafusion_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
2627

@@ -56,6 +57,7 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, Data
5657
let ret: Result<Vec<PyObject>, DataFusionError> = exprs
5758
.iter()
5859
.map(|expr| match expr {
60+
// TODO: should we also leverage `ScalarValue::to_pyarrow` here?
5961
Expr::Literal(v) => match v {
6062
ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
6163
ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
@@ -100,23 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
100102
let op_module = Python::import_bound(py, "operator")?;
101103
let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
102104
Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
103-
Expr::Literal(v) => match v {
104-
ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
105-
ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
106-
ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
107-
ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
108-
ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
109-
ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
110-
ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
111-
ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
112-
ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
113-
ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
114-
ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
115-
ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
116-
_ => Err(DataFusionError::Common(format!(
117-
"PyArrow can't handle ScalarValue: {v:?}"
118-
))),
119-
},
105+
Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
120106
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
121107
let operator = operator_to_py(op, &op_module)?;
122108
let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
@@ -138,8 +124,13 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
138124
let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
139125
.0
140126
.into_bound(py);
141-
// TODO: this expression does not seems like it should be `call_method0`
142-
Ok(expr.clone().call_method1("is_null", (expr,))?)
127+
128+
// https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression.is_null
129+
// Whether floating-point NaNs are considered null.
130+
let nan_is_null = false;
131+
132+
let res = expr.call_method1("is_null", (nan_is_null,))?;
133+
Ok(res)
143134
}
144135
Expr::Between(Between {
145136
expr,

0 commit comments

Comments (0)