diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 40037735d4..dfa2ebc818 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -547,6 +547,11 @@ def compile_concat(self, node: nodes.ConcatNode): child_frames = [ frame.rename( {col: id.sql for col, id in zip(frame.columns, node.output_ids)} + ).cast( + { + field.id.sql: _bigframes_dtype_to_polars_dtype(field.dtype) + for field in node.fields + } ) for frame in child_frames ] diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 5d3c814437..cf6e8a7e5c 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -424,7 +424,7 @@ def remap_refs( @dataclasses.dataclass(frozen=True, eq=False) class ConcatNode(BigFrameNode): - # TODO: Explcitly map column ids from each child + # TODO: Explicitly map column ids from each child? children: Tuple[BigFrameNode, ...] output_ids: Tuple[identifiers.ColumnId, ...] diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 28ab421905..8f669901a4 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -36,6 +36,7 @@ nodes.SliceNode, nodes.AggregateNode, nodes.FilterNode, + nodes.ConcatNode, ) _COMPATIBLE_SCALAR_OPS = ( diff --git a/tests/system/small/engines/test_concat.py b/tests/system/small/engines/test_concat.py new file mode 100644 index 0000000000..e10570fab2 --- /dev/null +++ b/tests/system/small/engines/test_concat.py @@ -0,0 +1,51 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes.core import array_value, ordering +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as it's fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_concat_self( + scalars_array_value: array_value.ArrayValue, + engine, +): + result = scalars_array_value.concat([scalars_array_value, scalars_array_value]) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_concat_filtered_sorted( + scalars_array_value: array_value.ArrayValue, + engine, +): + input_1 = scalars_array_value.select_columns(["float64_col", "int64_col"]).order_by( + [ordering.ascending_over("int64_col")] + ) + input_2 = scalars_array_value.filter_by_id("bool_col").select_columns( + ["float64_col", "int64_too"] + ) + + result = input_1.concat([input_2, input_1, input_2]) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)