diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index ec00e38606..e60bef1819 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -35,6 +35,7 @@ nodes.ProjectionNode, nodes.SliceNode, nodes.AggregateNode, + nodes.FilterNode, ) _COMPATIBLE_SCALAR_OPS = ( diff --git a/tests/system/small/engines/test_filtering.py b/tests/system/small/engines/test_filtering.py new file mode 100644 index 0000000000..9b7cd034b4 --- /dev/null +++ b/tests/system/small/engines/test_filtering.py @@ -0,0 +1,67 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from bigframes.core import array_value, expression, nodes +import bigframes.operations as ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars is used as the reference because it's fast and local. Generally, though, prefer the gbq engine where they disagree. 
+REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_filter_bool_col( + scalars_array_value: array_value.ArrayValue, + engine, +): + node = nodes.FilterNode( + scalars_array_value.node, predicate=expression.deref("bool_col") + ) + assert_equivalence_execution(node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_filter_expr_cond( + scalars_array_value: array_value.ArrayValue, + engine, +): + predicate = ops.gt_op.as_expr( + expression.deref("float64_col"), expression.deref("int64_col") + ) + node = nodes.FilterNode(scalars_array_value.node, predicate=predicate) + assert_equivalence_execution(node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_filter_true( + scalars_array_value: array_value.ArrayValue, + engine, +): + predicate = expression.const(True) + node = nodes.FilterNode(scalars_array_value.node, predicate=predicate) + assert_equivalence_execution(node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_filter_false( + scalars_array_value: array_value.ArrayValue, + engine, +): + predicate = expression.const(False) + node = nodes.FilterNode(scalars_array_value.node, predicate=predicate) + assert_equivalence_execution(node, REFERENCE_ENGINE, engine) diff --git a/tests/system/small/test_polars_execution.py b/tests/system/small/test_polars_execution.py index 0aed693b80..1568a76ec9 100644 --- a/tests/system/small/test_polars_execution.py +++ b/tests/system/small/test_polars_execution.py @@ -53,8 +53,7 @@ def test_polar_execution_sorted_filtered(session_w_polars, scalars_pandas_df_ind .to_pandas() ) - # Filter and isnull not supported by polar engine yet, so falls back to bq execution - assert session_w_polars._metrics.execution_count == (execution_count_before + 1) + 
assert session_w_polars._metrics.execution_count == execution_count_before assert_pandas_df_equal(bf_result, pd_result)