diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index a65ff6fe0c..06d889beaa 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -588,6 +588,11 @@ def endswith_op_impl(x: ibis_types.Value, op: ops.EndsWithOp): return any_match if any_match is not None else ibis_types.literal(False) +@scalar_op_compiler.register_unary_op(ops.StringSplitOp, pass_op=True) +def stringsplit_op_impl(x: ibis_types.Value, op: ops.StringSplitOp): + return typing.cast(ibis_types.StringValue, x).split(op.pat) + + @scalar_op_compiler.register_unary_op(ops.ZfillOp, pass_op=True) def zfill_op_impl(x: ibis_types.Value, op: ops.ZfillOp): str_value = typing.cast(ibis_types.StringValue, x) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index d2dc210e0d..2a344aff2d 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -405,6 +405,12 @@ def bigframes_dtype_to_ibis_dtype( return BIGFRAMES_TO_IBIS[bigframes_dtype] +def bigframes_dtype_to_arrow_dtype( + bigframes_dtype: Union[DtypeString, Dtype, np.dtype[Any]] +) -> pa.DataType: + return ibis_dtype_to_arrow_dtype(bigframes_dtype_to_ibis_dtype(bigframes_dtype)) + + def literal_to_ibis_scalar( literal, force_dtype: typing.Optional[Dtype] = None, validate: bool = True ): diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index e52f488d38..e8b79af58e 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -386,6 +386,19 @@ def output_type(self, *input_types): return op_typing.STRING_PREDICATE.output_type(input_types[0]) +@dataclasses.dataclass(frozen=True) +class StringSplitOp(UnaryOp): + name: typing.ClassVar[str] = "str_split" + pat: typing.Sequence[str] + + def output_type(self, *input_types): + input_type = input_types[0] + if not isinstance(input_type, pd.StringDtype): + raise TypeError("field accessor input must be a string type") + arrow_type = dtypes.bigframes_dtype_to_arrow_dtype(input_type) + return pd.ArrowDtype(pa.list_(arrow_type)) + + @dataclasses.dataclass(frozen=True) class EndsWithOp(UnaryOp): name: typing.ClassVar[str] = "str_endswith" @@ -463,9 +476,7 @@ def output_type(self, *input_types): raise TypeError("field accessor input must be a struct type") pa_result_type = pa_type[self.name_or_index].type - # TODO: Directly convert from arrow to pandas type - ibis_result_type = dtypes.arrow_dtype_to_ibis_dtype(pa_result_type) - return dtypes.ibis_dtype_to_bigframes_dtype(ibis_result_type) + return dtypes.arrow_dtype_to_bigframes_dtype(pa_result_type) @dataclasses.dataclass(frozen=True) diff --git a/bigframes/operations/strings.py b/bigframes/operations/strings.py index 883d19a1e3..22c325d7e0 100644 --- a/bigframes/operations/strings.py +++ b/bigframes/operations/strings.py @@ -247,6 +247,18 @@ def endswith( pat = (pat,) return self._apply_unary_op(ops.EndsWithOp(pat=pat)) + def split( + self, + pat: str = " ", + regex: Union[bool, None] = None, + ) -> series.Series: + if regex is True or (regex is None and len(pat) > 1): + raise NotImplementedError( + "Regular expressions aren't currently supported. Please set " + + f"`regex=False` and try again. {constants.FEEDBACK_LINK}" + ) + return self._apply_unary_op(ops.StringSplitOp(pat=pat)) + def zfill(self, width: int) -> series.Series: return self._apply_unary_op(ops.ZfillOp(width=width)) diff --git a/tests/system/small/operations/test_strings.py b/tests/system/small/operations/test_strings.py index 9654c77ec4..b8a8ad2d1e 100644 --- a/tests/system/small/operations/test_strings.py +++ b/tests/system/small/operations/test_strings.py @@ -531,3 +531,34 @@ def test_str_rjust(scalars_dfs): pd_result, bf_result, ) + + +@pytest.mark.parametrize( + ("pat", "regex"), + [ + pytest.param(" ", None, id="one_char"), + pytest.param("ll", False, id="two_chars"), + pytest.param( + " ", + True, + id="one_char_reg", + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + pytest.param( + "ll", + None, + id="two_chars_reg", + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + ], +) +def test_str_split_raise_errors(scalars_dfs, pat, regex): + scalars_df, scalars_pandas_df = scalars_dfs + col_name = "string_col" + bf_result = scalars_df[col_name].str.split(pat=pat, regex=regex).to_pandas() + pd_result = scalars_pandas_df[col_name].str.split(pat=pat, regex=regex) + + # TODO(b/336880368): Allow for NULL values for ARRAY columns in BigQuery. + pd_result = pd_result.apply(lambda x: [] if pd.isnull(x) is True else x) + + assert_series_equal(pd_result, bf_result, check_dtype=False) diff --git a/third_party/bigframes_vendored/pandas/core/strings/accessor.py b/third_party/bigframes_vendored/pandas/core/strings/accessor.py index 5bb69dc1f2..b02c23f945 100644 --- a/third_party/bigframes_vendored/pandas/core/strings/accessor.py +++ b/third_party/bigframes_vendored/pandas/core/strings/accessor.py @@ -940,6 +940,54 @@ def endswith( """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def split( + self, + pat: str = " ", + regex: typing.Union[bool, None] = None, + ): + """ + Split strings around given separator/delimiter. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import numpy as np + >>> bpd.options.display.progress_bar = None + + >>> s = bpd.Series( + ... [ + ... "a regular sentence", + ... "https://docs.python.org/index.html", + ... np.nan + ... ] + ... ) + >>> s.str.split() + 0 ['a' 'regular' 'sentence'] + 1 ['https://docs.python.org/index.html'] + 2 [] + dtype: list[pyarrow] + + The pat parameter can be used to split by other characters. + + >>> s.str.split("//", regex=False) + 0 ['a regular sentence'] + 1 ['https:' 'docs.python.org/index.html'] + 2 [] + dtype: list[pyarrow] + + Args: + pat (str, default " "): + String to split on. If not specified, split on whitespace. + regex (bool, default None): + Determines if the passed-in pattern is a regular expression. Regular + expressions aren't currently supported. Please set `regex=False` when + `pat` length is not 1. + + Returns: + bigframes.series.Series: Type matches caller. + """ + raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) + def match(self, pat: str, case: bool = True, flags: int = 0): """ Determine if each string starts with a match of a regular expression.