From f9d9fe71a91ea2e34c9c327d1d0f6fced7b6ef30 Mon Sep 17 00:00:00 2001
From: Trevor Bergeron
Date: Thu, 17 Jul 2025 01:23:08 +0000
Subject: [PATCH] feat: Add isin local execution to hybrid engine

---
 bigframes/core/compile/polars/compiler.py | 32 +++++++++++++++++++++++
 bigframes/core/rewrite/schema_binding.py  | 10 +++++++
 bigframes/session/polars_executor.py      |  1 +
 tests/system/small/engines/test_join.py   | 19 ++++++++++++++
 4 files changed, 62 insertions(+)

diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py
index c31c122078..5c29974c8a 100644
--- a/bigframes/core/compile/polars/compiler.py
+++ b/bigframes/core/compile/polars/compiler.py
@@ -503,6 +503,38 @@ def compile_join(self, node: nodes.JoinNode):
             left, right, node.type, left_on, right_on, node.joins_nulls
         )
 
+    @compile_node.register
+    def compile_isin(self, node: nodes.InNode):
+        """Compile an InNode into a left join that appends a boolean indicator.
+
+        The right child is made unique on the right key and tagged with a
+        literal ``True`` column. After a left join on the coerced keys, left
+        rows without a match get ``False`` in the indicator column.
+        """
+        left = self.compile_node(node.left_child)
+        # Keep one row per right key (presumably to avoid duplicating left rows).
+        right = self.compile_node(node.right_child).unique(node.right_col.id.sql)
+        right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql))
+
+        left_ex, right_ex = lowering._coerce_comparables(node.left_col, node.right_col)
+
+        left_pl_ex = self.expr_compiler.compile_expression(left_ex)
+        right_pl_ex = self.expr_compiler.compile_expression(right_ex)
+
+        joined = left.join(
+            right,
+            how="left",
+            left_on=left_pl_ex,
+            right_on=right_pl_ex,
+            # Note: join_nulls renamed to nulls_equal for polars 1.24
+            join_nulls=node.joins_nulls,  # type: ignore
+            coalesce=False,
+        )
+        passthrough = [pl.col(col_id) for col_id in left.columns]
+        # Unmatched left rows carry a null indicator after the join; report False.
+        indicator = pl.col(node.indicator_col.sql).fill_null(False)
+        return joined.select((*passthrough, indicator))
+
     def _ordered_join(
         self,
         left_frame: pl.LazyFrame,
diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py
index 40a00ff8f6..f7f2ca8c59 100644
--- a/bigframes/core/rewrite/schema_binding.py
+++ b/bigframes/core/rewrite/schema_binding.py
@@ -65,6 +65,16 @@ def
bind_schema_to_node(
             node,
             conditions=conditions,
         )
+    if isinstance(node, nodes.InNode):
+        return dataclasses.replace(
+            node,
+            left_col=ex.ResolvedDerefOp.from_field(
+                node.left_child.field_by_id[node.left_col.id]
+            ),
+            right_col=ex.ResolvedDerefOp.from_field(
+                node.right_child.field_by_id[node.right_col.id]
+            ),
+        )
 
     if isinstance(node, nodes.AggregateNode):
         aggregations = []
diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py
index 3c23e4c200..c6aaadb7a1 100644
--- a/bigframes/session/polars_executor.py
+++ b/bigframes/session/polars_executor.py
@@ -38,6 +38,7 @@
     nodes.FilterNode,
     nodes.ConcatNode,
     nodes.JoinNode,
+    nodes.InNode,
 )
 
 _COMPATIBLE_SCALAR_OPS = (
diff --git a/tests/system/small/engines/test_join.py b/tests/system/small/engines/test_join.py
index 402a41134b..91c199a437 100644
--- a/tests/system/small/engines/test_join.py
+++ b/tests/system/small/engines/test_join.py
@@ -88,3 +88,22 @@ def test_engines_cross_join(
     result, _ = scalars_array_value.relational_join(scalars_array_value, type="cross")
 
     assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
+
+
+@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize(
+    ("left_key", "right_key"),
+    [
+        ("int64_col", "float64_col"),
+        ("float64_col", "int64_col"),
+        ("int64_too", "int64_col"),
+    ],
+)
+def test_engines_isin(
+    scalars_array_value: array_value.ArrayValue, engine, left_key, right_key
+):
+    result, _ = scalars_array_value.isin(
+        scalars_array_value, lcol=left_key, rcol=right_key
+    )
+    # Presumably checks that `engine` agrees with REFERENCE_ENGINE on this plan - confirm.
+    assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)