diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index e968172c76..f657f28a6f 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -2332,6 +2332,8 @@ def merge( right_join_ids: typing.Sequence[str], sort: bool, suffixes: tuple[str, str] = ("_x", "_y"), + left_index: bool = False, + right_index: bool = False, ) -> Block: conditions = tuple( (lid, rid) for lid, rid in zip(left_join_ids, right_join_ids) @@ -2339,34 +2341,52 @@ def merge( joined_expr, (get_column_left, get_column_right) = self.expr.relational_join( other.expr, type=how, conditions=conditions ) - result_columns = [] - matching_join_labels = [] left_post_join_ids = tuple(get_column_left[id] for id in left_join_ids) right_post_join_ids = tuple(get_column_right[id] for id in right_join_ids) - joined_expr, coalesced_ids = coalesce_columns( - joined_expr, left_post_join_ids, right_post_join_ids, how=how, drop=False - ) + if left_index or right_index: + # For some reason pandas coalesces two joining columns if one side is an index. + joined_expr, resolved_join_ids = coalesce_columns( + joined_expr, left_post_join_ids, right_post_join_ids + ) + else: + joined_expr, resolved_join_ids = resolve_col_join_ids( # type: ignore + joined_expr, + left_post_join_ids, + right_post_join_ids, + how=how, + drop=False, + ) + + result_columns = [] + matching_join_labels = [] + # Select left value columns for col_id in self.value_columns: if col_id in left_join_ids: key_part = left_join_ids.index(col_id) matching_right_id = right_join_ids[key_part] if ( - self.col_id_to_label[col_id] + right_index + or self.col_id_to_label[col_id] == other.col_id_to_label[matching_right_id] ): matching_join_labels.append(self.col_id_to_label[col_id]) - result_columns.append(coalesced_ids[key_part]) + result_columns.append(resolved_join_ids[key_part]) else: result_columns.append(get_column_left[col_id]) else: result_columns.append(get_column_left[col_id]) + + # Select right value columns for col_id in other.value_columns: if col_id in right_join_ids: if other.col_id_to_label[col_id] in matching_join_labels: pass + elif left_index: + key_part = right_join_ids.index(col_id) + result_columns.append(resolved_join_ids[key_part]) else: result_columns.append(get_column_right[col_id]) else: @@ -2377,11 +2397,22 @@ def merge( joined_expr = joined_expr.order_by( [ ordering.OrderingExpression(ex.deref(col_id)) - for col_id in coalesced_ids + for col_id in resolved_join_ids ], ) - joined_expr = joined_expr.select_columns(result_columns) + left_idx_id_post_join = [get_column_left[id] for id in self.index_columns] + right_idx_id_post_join = [get_column_right[id] for id in other.index_columns] + index_cols = _resolve_index_col( + left_idx_id_post_join, + right_idx_id_post_join, + resolved_join_ids, + left_index, + right_index, + how, + ) + + joined_expr = joined_expr.select_columns(result_columns + index_cols) labels = utils.merge_column_labels( self.column_labels, other.column_labels, @@ -2400,13 +2431,13 @@ def merge( or other.index.is_null or self.session._default_index_type == bigframes.enums.DefaultIndexKind.NULL ): - expr = joined_expr - index_columns = [] + return Block(joined_expr, index_columns=[], column_labels=labels) + elif index_cols: + return Block(joined_expr, index_columns=index_cols, column_labels=labels) else: expr, offset_index_id = joined_expr.promote_offsets() index_columns = [offset_index_id] - - return Block(expr, index_columns=index_columns, column_labels=labels) + return Block(expr, index_columns=index_columns, column_labels=labels) def _align_both_axes( self, other: Block, how: str @@ -3115,7 +3146,7 @@ def join_mono_indexed( left_index = get_column_left[left.index_columns[0]] right_index = get_column_right[right.index_columns[0]] # Drop original indices from each side. and used the coalesced combination generated by the join. - combined_expr, coalesced_join_cols = coalesce_columns( + combined_expr, coalesced_join_cols = resolve_col_join_ids( combined_expr, [left_index], [right_index], how=how ) if sort: @@ -3180,7 +3211,7 @@ def join_multi_indexed( left_ids_post_join = [get_column_left[id] for id in left_join_ids] right_ids_post_join = [get_column_right[id] for id in right_join_ids] # Drop original indices from each side. and used the coalesced combination generated by the join. - combined_expr, coalesced_join_cols = coalesce_columns( + combined_expr, coalesced_join_cols = resolve_col_join_ids( combined_expr, left_ids_post_join, right_ids_post_join, how=how ) if sort: @@ -3223,13 +3254,17 @@ def resolve_label_id(label: Label) -> str: # TODO: Rewrite just to return expressions -def coalesce_columns( +def resolve_col_join_ids( expr: core.ArrayValue, left_ids: typing.Sequence[str], right_ids: typing.Sequence[str], how: str, drop: bool = True, ) -> Tuple[core.ArrayValue, Sequence[str]]: + """ + Collapses and selects the joining column IDs, with the assumption that + the ids are all belong to value columns. + """ result_ids = [] for left_id, right_id in zip(left_ids, right_ids): if how == "left" or how == "inner" or how == "cross": @@ -3241,7 +3276,6 @@ def coalesce_columns( if drop: expr = expr.drop_columns([left_id]) elif how == "outer": - coalesced_id = guid.generate_guid() expr, coalesced_id = expr.project_to_id( ops.coalesce_op.as_expr(left_id, right_id) ) @@ -3253,6 +3287,21 @@ def coalesce_columns( return expr, result_ids +def coalesce_columns( + expr: core.ArrayValue, + left_ids: typing.Sequence[str], + right_ids: typing.Sequence[str], +) -> tuple[core.ArrayValue, list[str]]: + result_ids = [] + for left_id, right_id in zip(left_ids, right_ids): + expr, coalesced_id = expr.project_to_id( + ops.coalesce_op.as_expr(left_id, right_id) + ) + result_ids.append(coalesced_id) + + return expr, result_ids + + def _cast_index(block: Block, dtypes: typing.Sequence[bigframes.dtypes.Dtype]): original_block = block result_ids = [] @@ -3468,3 +3517,35 @@ def _pd_index_to_array_value( rows.append(row) return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session) + + +def _resolve_index_col( + left_index_cols: list[str], + right_index_cols: list[str], + resolved_join_ids: list[str], + left_index: bool, + right_index: bool, + how: typing.Literal[ + "inner", + "left", + "outer", + "right", + "cross", + ], +) -> list[str]: + if left_index and right_index: + if how == "inner" or how == "left": + return left_index_cols + if how == "right": + return right_index_cols + if how == "outer": + return resolved_join_ids + else: + return [] + elif left_index and not right_index: + return right_index_cols + elif right_index and not left_index: + return left_index_cols + else: + # Joining with value columns only. Existing indices will be discarded. + return [] diff --git a/bigframes/core/reshape/merge.py b/bigframes/core/reshape/merge.py index 5c6cba4915..2afeb2a106 100644 --- a/bigframes/core/reshape/merge.py +++ b/bigframes/core/reshape/merge.py @@ -20,6 +20,7 @@ from typing import Literal, Sequence +from bigframes_vendored import constants import bigframes_vendored.pandas.core.reshape.merge as vendored_pandas_merge from bigframes import dataframe, series @@ -40,6 +41,8 @@ def merge( *, left_on: blocks.Label | Sequence[blocks.Label] | None = None, right_on: blocks.Label | Sequence[blocks.Label] | None = None, + left_index: bool = False, + right_index: bool = False, sort: bool = False, suffixes: tuple[str, str] = ("_x", "_y"), ) -> dataframe.DataFrame: @@ -59,35 +62,16 @@ def merge( ) return dataframe.DataFrame(result_block) - left_on, right_on = _validate_left_right_on( - left, right, on, left_on=left_on, right_on=right_on + left_join_ids, right_join_ids = _validate_left_right_on( + left, + right, + on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, ) - if utils.is_list_like(left_on): - left_on = list(left_on) # type: ignore - else: - left_on = [left_on] - - if utils.is_list_like(right_on): - right_on = list(right_on) # type: ignore - else: - right_on = [right_on] - - left_join_ids = [] - for label in left_on: # type: ignore - left_col_id = left._resolve_label_exact(label) - # 0 elements already throws an exception - if not left_col_id: - raise ValueError(f"No column {label} found in self.") - left_join_ids.append(left_col_id) - - right_join_ids = [] - for label in right_on: # type: ignore - right_col_id = right._resolve_label_exact(label) - if not right_col_id: - raise ValueError(f"No column {label} found in other.") - right_join_ids.append(right_col_id) - block = left._block.merge( right._block, how, @@ -95,6 +79,8 @@ def merge( right_join_ids, sort=sort, suffixes=suffixes, + left_index=left_index, + right_index=right_index, ) return dataframe.DataFrame(block) @@ -127,30 +113,106 @@ def _validate_left_right_on( *, left_on: blocks.Label | Sequence[blocks.Label] | None = None, right_on: blocks.Label | Sequence[blocks.Label] | None = None, -): - if on is not None: - if left_on is not None or right_on is not None: - raise ValueError( - "Can not pass both `on` and `left_on` + `right_on` params." - ) - return on, on - - if left_on is not None and right_on is not None: - return left_on, right_on + left_index: bool = False, + right_index: bool = False, +) -> tuple[list[str], list[str]]: + # Turn left_on and right_on to lists + if left_on is not None and not isinstance(left_on, (tuple, list)): + left_on = [left_on] + if right_on is not None and not isinstance(right_on, (tuple, list)): + right_on = [right_on] - left_cols = left.columns - right_cols = right.columns - common_cols = left_cols.intersection(right_cols) - if len(common_cols) == 0: + if left_index and left.index.nlevels > 1: raise ValueError( - "No common columns to perform merge on." - f"Merge options: left_on={left_on}, " - f"right_on={right_on}, " + f"Joining with multi-level index is not supported. {constants.FEEDBACK_LINK}" ) - if ( - not left_cols.join(common_cols, how="inner").is_unique - or not right_cols.join(common_cols, how="inner").is_unique - ): - raise ValueError(f"Data columns not unique: {repr(common_cols)}") + if right_index and right.index.nlevels > 1: + raise ValueError( + f"Joining with multi-level index is not supported. {constants.FEEDBACK_LINK}" + ) + + # The following checks are copied from Pandas. + if on is None and left_on is None and right_on is None: + if left_index and right_index: + return list(left._block.index_columns), list(right._block.index_columns) + elif left_index: + raise ValueError("Must pass right_on or right_index=True") + elif right_index: + raise ValueError("Must pass left_on or left_index=True") + else: + # use the common columns + common_cols = left.columns.intersection(right.columns) + if len(common_cols) == 0: + raise ValueError( + "No common columns to perform merge on. " + f"Merge options: left_on={left_on}, " + f"right_on={right_on}, " + f"left_index={left_index}, " + f"right_index={right_index}" + ) + if ( + not left.columns.join(common_cols, how="inner").is_unique + or not right.columns.join(common_cols, how="inner").is_unique + ): + raise ValueError(f"Data columns not unique: {repr(common_cols)}") + return _to_col_ids(left, common_cols.to_list()), _to_col_ids( + right, common_cols.to_list() + ) - return common_cols, common_cols + elif on is not None: + if left_on is not None or right_on is not None: + raise ValueError( + 'Can only pass argument "on" OR "left_on" ' + 'and "right_on", not a combination of both.' + ) + if left_index or right_index: + raise ValueError( + 'Can only pass argument "on" OR "left_index" ' + 'and "right_index", not a combination of both.' + ) + return _to_col_ids(left, on), _to_col_ids(right, on) + + elif left_on is not None: + if left_index: + raise ValueError( + 'Can only pass argument "left_on" OR "left_index" not both.' + ) + if not right_index and right_on is None: + raise ValueError('Must pass "right_on" OR "right_index".') + if right_index: + if len(left_on) != right.index.nlevels: + raise ValueError( + "len(left_on) must equal the number " + 'of levels in the index of "right"' + ) + return _to_col_ids(left, left_on), list(right._block.index_columns) + + elif right_on is not None: + if right_index: + raise ValueError( + 'Can only pass argument "right_on" OR "right_index" not both.' + ) + if not left_index and left_on is None: + raise ValueError('Must pass "left_on" OR "left_index".') + if left_index: + if len(right_on) != left.index.nlevels: + raise ValueError( + "len(right_on) must equal the number " + 'of levels in the index of "left"' + ) + return list(left._block.index_columns), _to_col_ids(right, right_on) + + # The user correctly specified left_on and right_on + if len(right_on) != len(left_on): # type: ignore + raise ValueError("len(right_on) must equal len(left_on)") + + return _to_col_ids(left, left_on), _to_col_ids(right, right_on) + + +def _to_col_ids( + df: dataframe.DataFrame, join_cols: blocks.Label | Sequence[blocks.Label] +) -> list[str]: + if utils.is_list_like(join_cols): + return [df._block.resolve_label_exact_or_error(col) for col in join_cols] + + return [df._block.resolve_label_exact_or_error(join_cols)] diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 7471cf587b..0ce602d1ea 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3650,6 +3650,8 @@ def merge( *, left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, right_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, + left_index: bool = False, + right_index: bool = False, sort: bool = False, suffixes: tuple[str, str] = ("_x", "_y"), ) -> DataFrame: @@ -3662,6 +3664,8 @@ def merge( on, left_on=left_on, right_on=right_on, + left_index=left_index, + right_index=right_index, sort=sort, suffixes=suffixes, ) diff --git a/tests/system/small/core/test_reshape.py b/tests/system/small/core/test_reshape.py new file mode 100644 index 0000000000..0850bf50bb --- /dev/null +++ b/tests/system/small/core/test_reshape.py @@ -0,0 +1,120 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import pandas.testing +import pytest + +from bigframes import session +from bigframes.core.reshape import merge + + +@pytest.mark.parametrize( + ("left_on", "right_on", "left_index", "right_index"), + [ + ("col_a", None, False, True), + (None, "col_d", True, False), + (None, None, True, True), + ], +) +@pytest.mark.parametrize("how", ["inner", "left", "right", "outer"]) +def test_join_with_index( + session: session.Session, left_on, right_on, left_index, right_index, how +): + df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=[1, 2, 3]) + bf1 = session.read_pandas(df1) + df2 = pd.DataFrame({"col_c": [1, 2, 3], "col_d": [2, 3, 4]}, index=[2, 3, 4]) + bf2 = session.read_pandas(df2) + + bf_result = merge.merge( + bf1, + bf2, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + how=how, + ).to_pandas() + pd_result = pd.merge( + df1, + df2, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + how=how, + ) + + pandas.testing.assert_frame_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + +@pytest.mark.parametrize( + ("on", "left_on", "right_on", "left_index", "right_index"), + [ + (None, "col_a", None, True, False), + (None, None, "col_c", None, True), + ("col_a", None, None, True, True), + ], +) +def test_join_with_index_invalid_index_arg_raise_error( + session: session.Session, on, left_on, right_on, left_index, right_index +): + df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=[1, 2, 3]) + bf1 = session.read_pandas(df1) + df2 = pd.DataFrame({"col_c": [1, 2, 3], "col_d": [2, 3, 4]}, index=[2, 3, 4]) + bf2 = session.read_pandas(df2) + + with pytest.raises(ValueError): + merge.merge( + bf1, + bf2, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + ).to_pandas() + + +@pytest.mark.parametrize( + ("left_on", "right_on", "left_index", "right_index"), + [ + (["col_a", "col_b"], None, False, True), + (None, ["col_c", "col_d"], True, False), + (None, None, True, True), + ], +) +@pytest.mark.parametrize("how", ["inner", "left", "right", "outer"]) +def test_join_with_multiindex_raises_error( + session: session.Session, left_on, right_on, left_index, right_index, how +): + multi_idx1 = pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 5)]) + df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=multi_idx1) + bf1 = session.read_pandas(df1) + multi_idx2 = pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 2)]) + df2 = pd.DataFrame({"col_c": [1, 2, 3], "col_d": [2, 3, 4]}, index=multi_idx2) + bf2 = session.read_pandas(df2) + + with pytest.raises(ValueError): + merge.merge( + bf1, + bf2, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + how=how, + ) diff --git a/tests/system/small/test_pandas.py b/tests/system/small/test_pandas.py index d2cde59729..2f4ddaecff 100644 --- a/tests/system/small/test_pandas.py +++ b/tests/system/small/test_pandas.py @@ -13,7 +13,6 @@ # limitations under the License. from datetime import datetime -import re import typing import pandas as pd @@ -440,10 +439,7 @@ def test_merge_raises_error_when_left_right_on_set(scalars_dfs): left = scalars_df[left_columns] right = scalars_df[right_columns] - with pytest.raises( - ValueError, - match=re.escape("Can not pass both `on` and `left_on` + `right_on` params."), - ): + with pytest.raises(ValueError): bpd.merge( left, right, diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 1e90e2e210..1e76454179 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -4593,6 +4593,8 @@ def merge( *, left_on: Optional[str] = None, right_on: Optional[str] = None, + left_index: bool = False, + right_index: bool = False, sort: bool = False, suffixes: tuple[str, str] = ("_x", "_y"), ) -> DataFrame: @@ -4705,6 +4707,10 @@ def merge( right_on (label or list of labels): Columns to join on in the right DataFrame. Either on or left_on + right_on must be passed in. + left_index (bool, default False): + Use the index from the left DataFrame as the join key. + right_index (bool, default False): + Use the index from the right DataFrame as the join key. sort: Default False. Sort the join keys lexicographically in the result DataFrame. If False, the order of the join keys depends diff --git a/third_party/bigframes_vendored/pandas/core/reshape/merge.py b/third_party/bigframes_vendored/pandas/core/reshape/merge.py index 66fb2c2160..49ff409c9a 100644 --- a/third_party/bigframes_vendored/pandas/core/reshape/merge.py +++ b/third_party/bigframes_vendored/pandas/core/reshape/merge.py @@ -13,6 +13,8 @@ def merge( *, left_on=None, right_on=None, + left_index: bool = False, + right_index: bool = False, sort=False, suffixes=("_x", "_y"), ): @@ -61,6 +63,10 @@ def merge( right_on (label or list of labels): Columns to join on in the right DataFrame. Either on or left_on + right_on must be passed in. + left_index (bool, default False): + Use the index from the left DataFrame as the join key. + right_index (bool, default False): + Use the index from the right DataFrame as the join key. sort: Default False. Sort the join keys lexicographically in the result DataFrame. If False, the order of the join keys depends