From 9733d732c209956acc80c26da0ee365ecbb385a9 Mon Sep 17 00:00:00 2001 From: jialuo Date: Tue, 19 Aug 2025 00:26:30 +0000 Subject: [PATCH] feat: Support callable for series where method --- bigframes/series.py | 16 ++++++++ .../large/functions/test_managed_function.py | 36 +++++++++++++++++ .../large/functions/test_remote_function.py | 39 +++++++++++++++++++ tests/system/small/test_series.py | 20 ++++++++++ 4 files changed, 111 insertions(+) diff --git a/bigframes/series.py b/bigframes/series.py index 321a023e0c..6f48935ec9 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -1478,7 +1478,23 @@ def items(self): for item in batch_df.squeeze(axis=1).items(): yield item + def _apply_callable(self, condition): + """ "Executes the possible callable condition as needed.""" + if callable(condition): + # When it's a bigframes function. + if hasattr(condition, "bigframes_bigquery_function"): + return self.apply(condition) + # When it's a plain Python function. + else: + return self.apply(condition, by_row=False) + + # When it's not a callable. + return condition + def where(self, cond, other=None): + cond = self._apply_callable(cond) + other = self._apply_callable(other) + value_id, cond_id, other_id, block = self._align3(cond, other) block, result_id = block.project_expr( ops.where_op.as_expr(value_id, cond_id, other_id) diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 209e4df1e3..262f5f0fe2 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1075,3 +1075,39 @@ def func_for_other(x): cleanup_function_assets( is_sum_positive_series_mf, session.bqclient, ignore_failures=False ) + + +def test_managed_function_series_where(session, dataset_id, scalars_dfs): + try: + + # The return type has to be bool type for callable where condition. + def _is_positive(s): + return s + 1000 > 0 + + is_positive_mf = session.udf( + input_types=int, + output_type=bool, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(_is_positive) + + scalars, scalars_pandas = scalars_dfs + + bf_int64 = scalars["int64_col"] + bf_int64_filtered = bf_int64.dropna() + pd_int64 = scalars_pandas["int64_col"] + pd_int64_filtered = pd_int64.dropna() + + # The cond is a callable (managed function) and the other is not a + # callable in series.where method. + bf_result = bf_int64_filtered.where( + cond=is_positive_mf, other=-bf_int64_filtered + ).to_pandas() + pd_result = pd_int64_filtered.where(cond=_is_positive, other=-pd_int64_filtered) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False) diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index 558b292c49..9e2c1e2c81 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2930,3 +2930,42 @@ def func_for_other(x): cleanup_function_assets( is_sum_positive_series_mf, session.bqclient, ignore_failures=False ) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_series_where(session, dataset_id, scalars_dfs): + try: + + def _ten_times(x): + return x * 10 + + ten_times_mf = session.remote_function( + input_types=float, + output_type=float, + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + )(_ten_times) + + scalars, scalars_pandas = scalars_dfs + + bf_int64 = scalars["float64_col"] + bf_int64_filtered = bf_int64.dropna() + pd_int64 = scalars_pandas["float64_col"] + pd_int64_filtered = pd_int64.dropna() + + # The cond is not a callable and the other is a callable (remote + # function) in series.where method. + bf_result = bf_int64_filtered.where( + cond=bf_int64_filtered < 0, other=ten_times_mf + ).to_pandas() + pd_result = pd_int64_filtered.where( + cond=pd_int64_filtered < 0, other=_ten_times + ) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. + cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False) diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index aa9afa6032..2172962046 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -3109,6 +3109,26 @@ def test_where_with_default(scalars_df_index, scalars_pandas_df_index): ) +def test_where_with_callable(scalars_df_index, scalars_pandas_df_index): + def _is_positive(x): + return x > 0 + + # Both cond and other are callable. + bf_result = ( + scalars_df_index["int64_col"] + .where(cond=_is_positive, other=lambda x: x * 10) + .to_pandas() + ) + pd_result = scalars_pandas_df_index["int64_col"].where( + cond=_is_positive, other=lambda x: x * 10 + ) + + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + @pytest.mark.parametrize( ("ordered"), [