From e300ed13658368e38d39919a49a0170ef1223891 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 25 Aug 2025 13:38:48 -0700 Subject: [PATCH 01/28] chore: implement StrPadOp, StrFindOp, StrExtractOp, StrRepeatOp, RegexReplaceStrOp and ReplaceStrOp compilers (#2015) --- .../sqlglot/expressions/unary_compiler.py | 105 ++++++++++++++++-- .../test_regex_replace_str/out.sql | 13 +++ .../test_replace_str/out.sql | 13 +++ .../test_str_extract/out.sql | 13 +++ .../test_unary_compiler/test_str_find/out.sql | 13 +++ .../test_str_find/out_with_end.sql | 13 +++ .../test_str_find/out_with_start.sql | 13 +++ .../test_str_find/out_with_start_and_end.sql | 13 +++ .../test_unary_compiler/test_str_pad/both.sql | 21 ++++ .../test_unary_compiler/test_str_pad/left.sql | 13 +++ .../test_str_pad/right.sql | 13 +++ .../test_str_repeat/out.sql | 13 +++ .../expressions/test_unary_compiler.py | 58 ++++++++++ 13 files changed, 306 insertions(+), 8 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_regex_replace_str/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_replace_str/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_repeat/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index ddaf04ae97..a5cffdc10a 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -177,14 +177,96 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) +@UNARY_OP_REGISTRATION.register(ops.StrContainsOp) +def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression: + return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%")) + + @UNARY_OP_REGISTRATION.register(ops.StrContainsRegexOp) def _(op: ops.StrContainsRegexOp, expr: TypedExpr) -> sge.Expression: return sge.RegexpLike(this=expr.expr, expression=sge.convert(op.pat)) -@UNARY_OP_REGISTRATION.register(ops.StrContainsOp) -def _(op: ops.StrContainsOp, expr: TypedExpr) -> sge.Expression: - return sge.Like(this=expr.expr, expression=sge.convert(f"%{op.pat}%")) +@UNARY_OP_REGISTRATION.register(ops.StrExtractOp) +def _(op: ops.StrExtractOp, expr: TypedExpr) -> sge.Expression: + return sge.RegexpExtract( + this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n) + ) + + +@UNARY_OP_REGISTRATION.register(ops.StrFindOp) +def _(op: ops.StrFindOp, 
expr: TypedExpr) -> sge.Expression: + # INSTR is 1-based, so we need to adjust the start position. + start = sge.convert(op.start + 1) if op.start is not None else sge.convert(1) + if op.end is not None: + # BigQuery's INSTR doesn't support `end`, so we need to use SUBSTR. + return sge.func( + "INSTR", + sge.Substring( + this=expr.expr, + start=start, + length=sge.convert(op.end - (op.start or 0)), + ), + sge.convert(op.substr), + ) - sge.convert(1) + else: + return sge.func( + "INSTR", + expr.expr, + sge.convert(op.substr), + start, + ) - sge.convert(1) + + +@UNARY_OP_REGISTRATION.register(ops.StrLstripOp) +def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression: + return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT") + + +@UNARY_OP_REGISTRATION.register(ops.StrPadOp) +def _(op: ops.StrPadOp, expr: TypedExpr) -> sge.Expression: + pad_length = sge.func( + "GREATEST", sge.Length(this=expr.expr), sge.convert(op.length) + ) + if op.side == "left": + return sge.func( + "LPAD", + expr.expr, + pad_length, + sge.convert(op.fillchar), + ) + elif op.side == "right": + return sge.func( + "RPAD", + expr.expr, + pad_length, + sge.convert(op.fillchar), + ) + else: # side == both + lpad_amount = sge.Cast( + this=sge.func( + "SAFE_DIVIDE", + sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)), + sge.convert(2), + ), + to="INT64", + ) + sge.Length(this=expr.expr) + return sge.func( + "RPAD", + sge.func( + "LPAD", + expr.expr, + lpad_amount, + sge.convert(op.fillchar), + ), + pad_length, + sge.convert(op.fillchar), + ) + + +@UNARY_OP_REGISTRATION.register(ops.StrRepeatOp) +def _(op: ops.StrRepeatOp, expr: TypedExpr) -> sge.Expression: + return sge.Repeat(this=expr.expr, times=sge.convert(op.repeats)) @UNARY_OP_REGISTRATION.register(ops.date_op) @@ -444,11 +526,6 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="MONTH"), expression=expr.expr) -@UNARY_OP_REGISTRATION.register(ops.StrLstripOp) -def _(op: ops.StrLstripOp, expr: TypedExpr) -> sge.Expression: - return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT") - - @UNARY_OP_REGISTRATION.register(ops.neg_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Neg(this=expr.expr) @@ -484,6 +561,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="QUARTER"), expression=expr.expr) +@UNARY_OP_REGISTRATION.register(ops.ReplaceStrOp) +def _(op: ops.ReplaceStrOp, expr: TypedExpr) -> sge.Expression: + return sge.func("REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl)) + + +@UNARY_OP_REGISTRATION.register(ops.RegexReplaceStrOp) +def _(op: ops.RegexReplaceStrOp, expr: TypedExpr) -> sge.Expression: + return sge.func( + "REGEXP_REPLACE", expr.expr, sge.convert(op.pat), sge.convert(op.repl) + ) + + @UNARY_OP_REGISTRATION.register(ops.reverse_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.func("REVERSE", expr.expr) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_regex_replace_str/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_regex_replace_str/out.sql new file mode 100644 index 0000000000..149df6706c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_regex_replace_str/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM 
`bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + REGEXP_REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_replace_str/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_replace_str/out.sql new file mode 100644 index 0000000000..3bd7e0e47e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_replace_str/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + REPLACE(`bfcol_0`, 'e', 'a') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql new file mode 100644 index 0000000000..a7fac093e2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_extract/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + REGEXP_EXTRACT(`bfcol_0`, '([a-z]*)') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql new file mode 100644 index 0000000000..dfc100e413 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(`bfcol_0`, 'e', 1) - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql new file mode 100644 index 0000000000..78edf662b9 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(SUBSTRING(`bfcol_0`, 1, 5), 'e') - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql new file mode 100644 index 0000000000..d0dfc11a53 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), 
`bfcte_1` AS ( + SELECT + *, + INSTR(`bfcol_0`, 'e', 3) - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql new file mode 100644 index 0000000000..a91ab32946 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + INSTR(SUBSTRING(`bfcol_0`, 3, 3), 'e') - 1 AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql new file mode 100644 index 0000000000..4701b0237a --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql @@ -0,0 +1,21 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + RPAD( + LPAD( + `bfcol_0`, + CAST(SAFE_DIVIDE(GREATEST(LENGTH(`bfcol_0`), 10) - LENGTH(`bfcol_0`), 2) AS INT64) + LENGTH(`bfcol_0`), + '-' + ), + GREATEST(LENGTH(`bfcol_0`), 10), + '-' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql new file mode 100644 index 0000000000..ee95900b3e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + LPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql new file mode 100644 index 0000000000..17e59c553f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + RPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_repeat/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_repeat/out.sql new file mode 100644 index 0000000000..1c94cfafe2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_repeat/out.sql @@ 
-0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + REPEAT(`bfcol_0`, 2) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 4a5b586c77..5c51068ce7 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -431,6 +431,18 @@ def test_quarter(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_replace_str(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.ReplaceStrOp("e", "a"), "string_col") + snapshot.assert_match(sql, "out.sql") + + +def test_regex_replace_str(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.RegexReplaceStrOp(r"e", "a"), "string_col") + snapshot.assert_match(sql, "out.sql") + + def test_reverse(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, ops.reverse_op, "string_col") @@ -466,6 +478,24 @@ def test_str_get(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_str_pad(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op( + bf_df, ops.StrPadOp(length=10, fillchar="-", side="left"), "string_col" + ) + snapshot.assert_match(sql, "left.sql") + + sql = _apply_unary_op( + bf_df, ops.StrPadOp(length=10, fillchar="-", side="right"), "string_col" + ) + snapshot.assert_match(sql, "right.sql") + + sql = _apply_unary_op( + bf_df, ops.StrPadOp(length=10, fillchar="-", side="both"), "string_col" + ) + snapshot.assert_match(sql, "both.sql") + + def test_str_slice(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, ops.StrSliceOp(1, 3), "string_col") @@ -506,6 +536,34 @@ def test_str_contains_regex(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_str_extract(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StrExtractOp(r"([a-z]*)", 1), "string_col") + + snapshot.assert_match(sql, "out.sql") + + +def test_str_repeat(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StrRepeatOp(2), "string_col") + snapshot.assert_match(sql, "out.sql") + + +def test_str_find(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=None, end=None), "string_col") + snapshot.assert_match(sql, "out.sql") + + sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=2, end=None), "string_col") + snapshot.assert_match(sql, "out_with_start.sql") + + sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=None, end=5), "string_col") + snapshot.assert_match(sql, "out_with_end.sql") + + sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=2, end=5), "string_col") + snapshot.assert_match(sql, "out_with_start_and_end.sql") + + def test_strip(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, 
ops.StrStripOp(" "), "string_col") From d9d725cfbc3dca9e66b460cae4084e25162f2acf Mon Sep 17 00:00:00 2001 From: jialuoo Date: Mon, 25 Aug 2025 15:40:33 -0700 Subject: [PATCH 02/28] feat: Support args in series apply method (#2013) * feat: Support args in series apply method * resolve the comments --- bigframes/series.py | 32 ++++++++++++--- .../large/functions/test_managed_function.py | 28 +++++++++++++ .../large/functions/test_remote_function.py | 39 +++++++++++++++++++ 3 files changed, 94 insertions(+), 5 deletions(-) diff --git a/bigframes/series.py b/bigframes/series.py index 80952f38bc..c95b2ca37f 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -1904,9 +1904,22 @@ def _groupby_values( ) def apply( - self, func, by_row: typing.Union[typing.Literal["compat"], bool] = "compat" + self, + func, + by_row: typing.Union[typing.Literal["compat"], bool] = "compat", + *, + args: typing.Tuple = (), ) -> Series: - # TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs + # Note: This signature differs from pandas.Series.apply. Specifically, + # `args` is keyword-only and `by_row` is a custom parameter here. Full + # alignment would involve breaking changes. However, given that by_row + # is not frequently used, we defer any such changes until there is a + # clear need based on user feedback. + # + # See pandas docs for reference: + # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.apply.html + + # TODO(shobs, b/274645634): Support convert_dtype, **kwargs # is actually a ternary op if by_row not in ["compat", False]: @@ -1950,10 +1963,19 @@ def apply( raise # We are working with bigquery function at this point - result_series = self._apply_unary_op( - ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True) - ) + if args: + result_series = self._apply_nary_op( + ops.NaryRemoteFunctionOp(function_def=func.udf_def), args + ) + # TODO(jialuo): Investigate why `_apply_nary_op` drops the series + # `name`. Manually reassigning it here as a temporary fix. + result_series.name = self.name + else: + result_series = self._apply_unary_op( + ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True) + ) result_series = func._post_process_series(result_series) + return result_series def combine( diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 43fb322567..6f5ef5b534 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1121,3 +1121,31 @@ def _is_positive(s): finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False) + + +def test_managed_function_series_apply_args(session, dataset_id, scalars_dfs): + try: + + with pytest.warns(bfe.PreviewWarning, match="udf is in preview."): + + @session.udf(dataset=dataset_id, name=prefixer.create_prefix()) + def foo_list(x: int, y0: float, y1: bytes, y2: bool) -> list[str]: + return [str(x), str(y0), str(y1), str(y2)] + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = ( + scalars_df["int64_too"] + .apply(foo_list, args=(12.34, b"hello world", False)) + .to_pandas() + ) + pd_result = scalars_pandas_df["int64_too"].apply( + foo_list, args=(12.34, b"hello world", False) + ) + + # Ignore any dtype difference. 
+ pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False) diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index 1c44b7e5fb..cb61d3769c 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2979,3 +2979,42 @@ def _ten_times(x): finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_series_apply_args(session, dataset_id, scalars_dfs): + try: + + @session.remote_function( + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + ) + def foo(x: int, y: bool, z: float) -> str: + if y: + return f"{x}: y is True." + if z > 0.0: + return f"{x}: y is False and z is positive." + return f"{x}: y is False and z is non-positive." + + scalars_df, scalars_pandas_df = scalars_dfs + + args1 = (True, 10.0) + bf_result = scalars_df["int64_too"].apply(foo, args=args1).to_pandas() + pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args1) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + args2 = (False, -10.0) + foo_ref = session.read_gbq_function(foo.bigframes_bigquery_function) + + bf_result = scalars_df["int64_too"].apply(foo_ref, args=args2).to_pandas() + pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args2) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. 
+ cleanup_function_assets(foo, session.bqclient, ignore_failures=False) From 8f2cad24a6a2fcacbfe49552861726be16ed41d9 Mon Sep 17 00:00:00 2001 From: jialuoo Date: Mon, 25 Aug 2025 16:53:14 -0700 Subject: [PATCH 03/28] test: Add unit test for has_conflict_input_type (#2022) --- .../functions/test_remote_function_utils.py | 65 ++++++++++++++++--- 1 file changed, 55 insertions(+), 10 deletions(-) diff --git a/tests/unit/functions/test_remote_function_utils.py b/tests/unit/functions/test_remote_function_utils.py index 0e4ca7a2ac..687c599985 100644 --- a/tests/unit/functions/test_remote_function_utils.py +++ b/tests/unit/functions/test_remote_function_utils.py @@ -217,13 +217,62 @@ def test_package_existed_helper(): assert not _utils._package_existed([], "pandas") +# Helper functions for signature inspection tests +def _func_one_arg_annotated(x: int) -> int: + """A function with one annotated arg and an annotated return type.""" + return x + + +def _func_one_arg_unannotated(x): + """A function with one unannotated arg and no return type annotation.""" + return x + + +def _func_two_args_annotated(x: int, y: str): + """A function with two annotated args and no return type annotation.""" + return f"{x}{y}" + + +def _func_two_args_unannotated(x, y): + """A function with two unannotated args and no return type annotation.""" + return f"{x}{y}" + + +def test_has_conflict_input_type_too_few_inputs(): + """Tests conflict when there are fewer input types than parameters.""" + signature = inspect.signature(_func_one_arg_annotated) + assert _utils.has_conflict_input_type(signature, input_types=[]) + + +def test_has_conflict_input_type_too_many_inputs(): + """Tests conflict when there are more input types than parameters.""" + signature = inspect.signature(_func_one_arg_annotated) + assert _utils.has_conflict_input_type(signature, input_types=[int, str]) + + +def test_has_conflict_input_type_type_mismatch(): + """Tests has_conflict_input_type with a conflicting type annotation.""" + signature = inspect.signature(_func_two_args_annotated) + + # The second type (bool) conflicts with the annotation (str). + assert _utils.has_conflict_input_type(signature, input_types=[int, bool]) + + +def test_has_conflict_input_type_no_conflict_annotated(): + """Tests that a matching, annotated signature is compatible.""" + signature = inspect.signature(_func_two_args_annotated) + assert not _utils.has_conflict_input_type(signature, input_types=[int, str]) + + +def test_has_conflict_input_type_no_conflict_unannotated(): + """Tests that a signature with no annotations is always compatible.""" + signature = inspect.signature(_func_two_args_unannotated) + assert not _utils.has_conflict_input_type(signature, input_types=[int, float]) + + def test_has_conflict_output_type_no_conflict(): """Tests has_conflict_output_type with type annotation.""" - # Helper functions with type annotation for has_conflict_output_type. - def _func_with_return_type(x: int) -> int: - return x - - signature = inspect.signature(_func_with_return_type) + signature = inspect.signature(_func_one_arg_annotated) assert _utils.has_conflict_output_type(signature, output_type=float) assert not _utils.has_conflict_output_type(signature, output_type=int) @@ -231,11 +280,7 @@ def _func_with_return_type(x: int) -> int: def test_has_conflict_output_type_no_annotation(): """Tests has_conflict_output_type without type annotation.""" - # Helper functions without type annotation for has_conflict_output_type. 
- def _func_without_return_type(x): - return x - - signature = inspect.signature(_func_without_return_type) + signature = inspect.signature(_func_one_arg_unannotated) assert not _utils.has_conflict_output_type(signature, output_type=int) assert not _utils.has_conflict_output_type(signature, output_type=float) From 9d4504be310d38b63515d67c0f60d2e48e68c7b5 Mon Sep 17 00:00:00 2001 From: jialuoo Date: Mon, 25 Aug 2025 17:27:47 -0700 Subject: [PATCH 04/28] feat: Support callable for dataframe mask method (#2020) --- bigframes/dataframe.py | 27 +++++++------ .../large/functions/test_managed_function.py | 38 ++++++++++++++++--- .../large/functions/test_remote_function.py | 25 ++++++++++-- tests/system/small/test_dataframe.py | 12 ++++++ 4 files changed, 81 insertions(+), 21 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 921893fb83..85b8245272 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2828,6 +2828,19 @@ def itertuples( for item in df.itertuples(index=index, name=name): yield item + def _apply_callable(self, condition): + """Executes the possible callable condition as needed.""" + if callable(condition): + # When it's a bigframes function. + if hasattr(condition, "bigframes_bigquery_function"): + return self.apply(condition, axis=1) + + # When it's a plain Python function. + return condition(self) + + # When it's not a callable. + return condition + def where(self, cond, other=None): if isinstance(other, bigframes.series.Series): raise ValueError("Seires is not a supported replacement type!") @@ -2839,16 +2852,8 @@ def where(self, cond, other=None): # Execute it with the DataFrame when cond or/and other is callable. # It can be either a plain python function or remote/managed function. - if callable(cond): - if hasattr(cond, "bigframes_bigquery_function"): - cond = self.apply(cond, axis=1) - else: - cond = cond(self) - if callable(other): - if hasattr(other, "bigframes_bigquery_function"): - other = self.apply(other, axis=1) - else: - other = other(self) + cond = self._apply_callable(cond) + other = self._apply_callable(other) aligned_block, (_, _) = self._block.join(cond._block, how="left") # No left join is needed when 'other' is None or constant. @@ -2899,7 +2904,7 @@ def where(self, cond, other=None): return result def mask(self, cond, other=None): - return self.where(~cond, other=other) + return self.where(~self._apply_callable(cond), other=other) def dropna( self, diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 6f5ef5b534..73335afa3c 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -965,7 +965,7 @@ def float_parser(row): ) -def test_managed_function_df_where(session, dataset_id, scalars_dfs): +def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -987,7 +987,7 @@ def is_sum_positive(a, b): pd_int64_df = scalars_pandas_df[int64_cols] pd_int64_df_filtered = pd_int64_df.dropna() - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas() # Pandas doesn't support such case, use following as workaround. pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0) @@ -995,7 +995,7 @@ def is_sum_positive(a, b): # Ignore any dtype difference. 
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) - # Make sure the read_gbq_function path works for this function. + # Make sure the read_gbq_function path works for dataframe.where method. is_sum_positive_ref = session.read_gbq_function( function_name=is_sum_positive_mf.bigframes_bigquery_function ) @@ -1012,6 +1012,19 @@ def is_sum_positive(a, b): bf_result_gbq, pd_result_gbq, check_dtype=False ) + # Test callable condition in dataframe.mask method. + bf_result_gbq = bf_int64_df_filtered.mask( + is_sum_positive_ref, -bf_int64_df_filtered + ).to_pandas() + pd_result_gbq = pd_int64_df_filtered.mask( + pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered + ) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal( + bf_result_gbq, pd_result_gbq, check_dtype=False + ) + finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets( @@ -1019,7 +1032,7 @@ def is_sum_positive(a, b): ) -def test_managed_function_df_where_series(session, dataset_id, scalars_dfs): +def test_managed_function_df_where_mask_series(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -1041,14 +1054,14 @@ def is_sum_positive_series(s): pd_int64_df = scalars_pandas_df[int64_cols] pd_int64_df_filtered = pd_int64_df.dropna() - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas() pd_result = pd_int64_df_filtered.where(is_sum_positive_series) # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) - # Make sure the read_gbq_function path works for this function. + # Make sure the read_gbq_function path works for dataframe.where method. is_sum_positive_series_ref = session.read_gbq_function( function_name=is_sum_positive_series_mf.bigframes_bigquery_function, is_row_processor=True, @@ -1070,6 +1083,19 @@ def func_for_other(x): bf_result_gbq, pd_result_gbq, check_dtype=False ) + # Test callable condition in dataframe.mask method. + bf_result_gbq = bf_int64_df_filtered.mask( + is_sum_positive_series_ref, func_for_other + ).to_pandas() + pd_result_gbq = pd_int64_df_filtered.mask( + is_sum_positive_series, func_for_other + ) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal( + bf_result_gbq, pd_result_gbq, check_dtype=False + ) + finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets( diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index cb61d3769c..3c453a52a4 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2850,7 +2850,7 @@ def foo(x: int) -> int: @pytest.mark.flaky(retries=2, delay=120) -def test_remote_function_df_where(session, dataset_id, scalars_dfs): +def test_remote_function_df_where_mask(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -2873,7 +2873,7 @@ def is_sum_positive(a, b): pd_int64_df = scalars_pandas_df[int64_cols] pd_int64_df_filtered = pd_int64_df.dropna() - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas() # Pandas doesn't support such case, use following as workaround. 
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0) @@ -2881,6 +2881,14 @@ def is_sum_positive(a, b): # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + # Test callable condition in dataframe.mask method. + bf_result = bf_int64_df_filtered.mask(is_sum_positive_mf, 0).to_pandas() + # Pandas doesn't support such case, use following as workaround. + pd_result = pd_int64_df_filtered.mask(pd_int64_df_filtered.sum(axis=1) > 0, 0) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets( @@ -2889,7 +2897,7 @@ def is_sum_positive(a, b): @pytest.mark.flaky(retries=2, delay=120) -def test_remote_function_df_where_series(session, dataset_id, scalars_dfs): +def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -2916,7 +2924,7 @@ def is_sum_positive_series(s): def func_for_other(x): return -x - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where( is_sum_positive_series, func_for_other ).to_pandas() @@ -2925,6 +2933,15 @@ def func_for_other(x): # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + # Test callable condition in dataframe.mask method. + bf_result = bf_int64_df_filtered.mask( + is_sum_positive_series_mf, func_for_other + ).to_pandas() + pd_result = pd_int64_df_filtered.mask(is_sum_positive_series, func_for_other) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets( diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 8a570ade45..51f4674ba4 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -406,6 +406,18 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index): pandas.testing.assert_frame_equal(bf_result, pd_result) +def test_mask_callable(scalars_df_index, scalars_pandas_df_index): + def is_positive(x): + return x > 0 + + bf_df = scalars_df_index[["int64_too", "int64_col", "float64_col"]] + pd_df = scalars_pandas_df_index[["int64_too", "int64_col", "float64_col"]] + bf_result = bf_df.mask(cond=is_positive, other=lambda x: x + 1).to_pandas() + pd_result = pd_df.mask(cond=is_positive, other=lambda x: x + 1) + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + def test_where_multi_column(scalars_df_index, scalars_pandas_df_index): # Test when a dataframe has multi-columns. 
columns = ["int64_col", "float64_col"] From 9ed00780bd7ec2f8c528dcc762bf8bb49fcc98ea Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 26 Aug 2025 11:18:13 -0700 Subject: [PATCH 05/28] chore: implement GeoStBufferOp, geo_st_centroid_op, geo_st_convexhull_op and MapOp for sqlglot compilers (#2021) --- .../sqlglot/expressions/unary_compiler.py | 32 +++++++++++++++++++ .../test_geo_st_buffer/out.sql | 13 ++++++++ .../test_geo_st_centroid/out.sql | 13 ++++++++ .../test_geo_st_convexhull/out.sql | 13 ++++++++ .../test_unary_compiler/test_map/out.sql | 13 ++++++++ .../expressions/test_unary_compiler.py | 30 +++++++++++++++++ 6 files changed, 114 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_buffer/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_centroid/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_convexhull/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_map/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index a5cffdc10a..3d527f2a2f 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -344,6 +344,27 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.func("ST_BOUNDARY", expr.expr) +@UNARY_OP_REGISTRATION.register(ops.GeoStBufferOp) +def _(op: ops.GeoStBufferOp, expr: TypedExpr) -> sge.Expression: + return sge.func( + "ST_BUFFER", + expr.expr, + sge.convert(op.buffer_radius), + sge.convert(op.num_seg_quarter_circle), + sge.convert(op.use_spheroid), + ) + + +@UNARY_OP_REGISTRATION.register(ops.geo_st_centroid_op) +def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: + return sge.func("ST_CENTROID", expr.expr) + + +@UNARY_OP_REGISTRATION.register(ops.geo_st_convexhull_op) +def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: + return sge.func("ST_CONVEXHULL", expr.expr) + + @UNARY_OP_REGISTRATION.register(ops.geo_st_geogfromtext_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.func("SAFE.ST_GEOGFROMTEXT", expr.expr) @@ -516,6 +537,17 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Lower(this=expr.expr) +@UNARY_OP_REGISTRATION.register(ops.MapOp) +def _(op: ops.MapOp, expr: TypedExpr) -> sge.Expression: + return sge.Case( + this=expr.expr, + ifs=[ + sge.If(this=sge.convert(key), true=sge.convert(value)) + for key, value in op.mappings + ], + ) + + @UNARY_OP_REGISTRATION.register(ops.minute_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="MINUTE"), expression=expr.expr) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_buffer/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_buffer/out.sql new file mode 100644 index 0000000000..9669c39a9f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_buffer/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `geography_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ST_BUFFER(`bfcol_0`, 1.0, 8.0, FALSE) AS 
`bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `geography_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_centroid/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_centroid/out.sql new file mode 100644 index 0000000000..97867318ad --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_centroid/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `geography_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ST_CENTROID(`bfcol_0`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `geography_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_convexhull/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_convexhull/out.sql new file mode 100644 index 0000000000..8bb5801173 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_geo_st_convexhull/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `geography_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ST_CONVEXHULL(`bfcol_0`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `geography_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_map/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_map/out.sql new file mode 100644 index 0000000000..a17d6584ce --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_map/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE `bfcol_0` WHEN 'value1' THEN 'mapped1' END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 5c51068ce7..2a3297a46c 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -174,6 +174,27 @@ def test_geo_st_boundary(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_geo_st_buffer(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["geography_col"]] + sql = _apply_unary_op(bf_df, ops.GeoStBufferOp(1.0, 8.0, False), "geography_col") + + snapshot.assert_match(sql, "out.sql") + + +def test_geo_st_centroid(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["geography_col"]] + sql = _apply_unary_op(bf_df, ops.geo_st_centroid_op, "geography_col") + + snapshot.assert_match(sql, "out.sql") + + +def test_geo_st_convexhull(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["geography_col"]] + sql = _apply_unary_op(bf_df, ops.geo_st_convexhull_op, "geography_col") + + snapshot.assert_match(sql, "out.sql") + + def test_geo_st_geogfromtext(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, ops.geo_st_geogfromtext_op, "string_col") @@ 
-370,6 +391,15 @@ def test_lower(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_map(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op( + bf_df, ops.MapOp(mappings=(("value1", "mapped1"),)), "string_col" + ) + + snapshot.assert_match(sql, "out.sql") + + def test_lstrip(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, ops.StrLstripOp(" "), "string_col") From cfa4b2a5e059164d1d961f69191fb7541e5882a1 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 26 Aug 2025 21:33:13 -0700 Subject: [PATCH 06/28] chore: implement StartsWithOp, EndsWithOp, StringSplitOp and ZfillOp for sqlglot compilers (#2027) --- .../sqlglot/expressions/unary_compiler.py | 58 +++++++++++++++++++ .../test_endswith/multiple_patterns.sql | 13 +++++ .../test_endswith/no_pattern.sql | 13 +++++ .../test_endswith/single_pattern.sql | 13 +++++ .../test_startswith/multiple_patterns.sql | 13 +++++ .../test_startswith/no_pattern.sql | 13 +++++ .../test_startswith/single_pattern.sql | 13 +++++ .../test_string_split/out.sql | 13 +++++ .../test_unary_compiler/test_zfill/out.sql | 17 ++++++ .../expressions/test_unary_compiler.py | 36 ++++++++++++ 10 files changed, 202 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 3d527f2a2f..98f1603be7 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -14,6 +14,7 @@ from __future__ import annotations +import functools import typing import pandas as pd @@ -292,6 +293,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="DAYOFYEAR"), expression=expr.expr) +@UNARY_OP_REGISTRATION.register(ops.EndsWithOp) +def _(op: ops.EndsWithOp, expr: TypedExpr) -> sge.Expression: + if not op.pat: + return sge.false() + + def to_endswith(pat: str) -> sge.Expression: + return sge.func("ENDS_WITH", expr.expr, sge.convert(pat)) + + conditions = [to_endswith(pat) for pat in op.pat] + return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions) + + @UNARY_OP_REGISTRATION.register(ops.exp_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Case( @@ -633,6 +646,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) 
+@UNARY_OP_REGISTRATION.register(ops.StartsWithOp) +def _(op: ops.StartsWithOp, expr: TypedExpr) -> sge.Expression: + if not op.pat: + return sge.false() + + def to_startswith(pat: str) -> sge.Expression: + return sge.func("STARTS_WITH", expr.expr, sge.convert(pat)) + + conditions = [to_startswith(pat) for pat in op.pat] + return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions) + + @UNARY_OP_REGISTRATION.register(ops.StrStripOp) def _(op: ops.StrStripOp, expr: TypedExpr) -> sge.Expression: return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr) @@ -656,6 +681,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) +@UNARY_OP_REGISTRATION.register(ops.StringSplitOp) +def _(op: ops.StringSplitOp, expr: TypedExpr) -> sge.Expression: + return sge.Split(this=expr.expr, expression=sge.convert(op.pat)) + + @UNARY_OP_REGISTRATION.register(ops.StrGetOp) def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression: return sge.Substring( @@ -808,3 +838,31 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: @UNARY_OP_REGISTRATION.register(ops.year_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) + + +@UNARY_OP_REGISTRATION.register(ops.ZfillOp) +def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression: + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ( + this=sge.Substring( + this=expr.expr, start=sge.convert(1), length=sge.convert(1) + ), + expression=sge.convert("-"), + ), + true=sge.Concat( + expressions=[ + sge.convert("-"), + sge.func( + "LPAD", + sge.Substring(this=expr.expr, start=sge.convert(1)), + sge.convert(op.width - 1), + sge.convert("0"), + ), + ] + ), + ) + ], + default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")), + ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql new file mode 100644 index 0000000000..f224471e79 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ENDS_WITH(`bfcol_0`, 'ab') OR ENDS_WITH(`bfcol_0`, 'cd') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql new file mode 100644 index 0000000000..e9f61ddd7c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + FALSE AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql new file mode 100644 index 0000000000..a4e259f0b2 --- 
/dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ENDS_WITH(`bfcol_0`, 'ab') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql new file mode 100644 index 0000000000..061b57e208 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + STARTS_WITH(`bfcol_0`, 'ab') OR STARTS_WITH(`bfcol_0`, 'cd') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql new file mode 100644 index 0000000000..e9f61ddd7c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + FALSE AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql new file mode 100644 index 0000000000..726ce05b8c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + STARTS_WITH(`bfcol_0`, 'ab') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql new file mode 100644 index 0000000000..fea0d6eaf1 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + SPLIT(`bfcol_0`, ',') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql new file mode 100644 index 0000000000..e5d70ab44b --- /dev/null +++ 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql @@ -0,0 +1,17 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN SUBSTRING(`bfcol_0`, 1, 1) = '-' + THEN CONCAT('-', LPAD(SUBSTRING(`bfcol_0`, 1), 9, '0')) + ELSE LPAD(`bfcol_0`, 10, '0') + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 2a3297a46c..f011721ee5 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -125,6 +125,18 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_endswith(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab",)), "string_col") + snapshot.assert_match(sql, "single_pattern.sql") + + sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab", "cd")), "string_col") + snapshot.assert_match(sql, "multiple_patterns.sql") + + sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=()), "string_col") + snapshot.assert_match(sql, "no_pattern.sql") + + def test_exp(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["float64_col"]] sql = _apply_unary_op(bf_df, ops.exp_op, "float64_col") @@ -501,6 +513,18 @@ def test_sqrt(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_startswith(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab",)), "string_col") + snapshot.assert_match(sql, "single_pattern.sql") + + sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab", "cd")), "string_col") + snapshot.assert_match(sql, "multiple_patterns.sql") + + sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=()), "string_col") + snapshot.assert_match(sql, "no_pattern.sql") + + def test_str_get(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, ops.StrGetOp(1), "string_col") @@ -650,6 +674,12 @@ def test_sinh(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_string_split(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StringSplitOp(pat=","), "string_col") + snapshot.assert_match(sql, "out.sql") + + def test_tan(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["float64_col"]] sql = _apply_unary_op(bf_df, ops.tan_op, "float64_col") @@ -790,3 +820,9 @@ def test_year(scalar_types_df: bpd.DataFrame, snapshot): sql = _apply_unary_op(bf_df, ops.year_op, "timestamp_col") snapshot.assert_match(sql, "out.sql") + + +def test_zfill(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.ZfillOp(width=10), "string_col") + snapshot.assert_match(sql, "out.sql") From 6bf06a7e16f6aec9f19f748b07e9e0fb2c276a4a Mon Sep 17 00:00:00 2001 From: jialuoo Date: Wed, 27 Aug 2025 10:27:46 -0700 Subject: [PATCH 07/28] test: Add unit test for get_bigframes_function_name (#2031) --- .../functions/test_remote_function_utils.py | 26 +++++++++++++++++++ 1 file changed, 26 
insertions(+) diff --git a/tests/unit/functions/test_remote_function_utils.py b/tests/unit/functions/test_remote_function_utils.py index 687c599985..e46e04b427 100644 --- a/tests/unit/functions/test_remote_function_utils.py +++ b/tests/unit/functions/test_remote_function_utils.py @@ -76,6 +76,32 @@ def test_get_cloud_function_name(func_hash, session_id, uniq_suffix, expected_na assert result == expected_name +@pytest.mark.parametrize( + "function_hash, session_id, uniq_suffix, expected_name", + [ + ( + "hash123", + "session456", + None, + "bigframes_session456_hash123", + ), + ( + "hash789", + "sessionABC", + "suffixDEF", + "bigframes_sessionABC_hash789_suffixDEF", + ), + ], +) +def test_get_bigframes_function_name( + function_hash, session_id, uniq_suffix, expected_name +): + """Tests the construction of the BigQuery function name from its parts.""" + result = _utils.get_bigframes_function_name(function_hash, session_id, uniq_suffix) + + assert result == expected_name + + def test_get_updated_package_requirements_no_extra_package(): """Tests with no extra package.""" result = _utils.get_updated_package_requirements(capture_references=False) From fc44bc8f3a96daf6996623e9b6938975f4dfd6c5 Mon Sep 17 00:00:00 2001 From: jialuoo Date: Wed, 27 Aug 2025 10:59:50 -0700 Subject: [PATCH 08/28] test: Add unit test for get_hash (#2025) --- .../functions/test_remote_function_utils.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/unit/functions/test_remote_function_utils.py b/tests/unit/functions/test_remote_function_utils.py index e46e04b427..8ddd39d857 100644 --- a/tests/unit/functions/test_remote_function_utils.py +++ b/tests/unit/functions/test_remote_function_utils.py @@ -243,6 +243,78 @@ def test_package_existed_helper(): assert not _utils._package_existed([], "pandas") +def _function_add_one(x): + return x + 1 + + +def _function_add_two(x): + return x + 2 + + +@pytest.mark.parametrize( + "func1, func2, should_be_equal, description", + [ + ( + _function_add_one, + _function_add_one, + True, + "Identical functions should have the same hash.", + ), + ( + _function_add_one, + _function_add_two, + False, + "Different functions should have different hashes.", + ), + ], +) +def test_get_hash_without_package_requirements( + func1, func2, should_be_equal, description +): + """Tests function hashes without any requirements.""" + hash1 = _utils.get_hash(func1) + hash2 = _utils.get_hash(func2) + + if should_be_equal: + assert hash1 == hash2, f"FAILED: {description}" + else: + assert hash1 != hash2, f"FAILED: {description}" + + +@pytest.mark.parametrize( + "reqs1, reqs2, should_be_equal, description", + [ + ( + None, + ["pandas>=1.0"], + False, + "Hash with or without requirements should differ from hash.", + ), + ( + ["pandas", "numpy", "scikit-learn"], + ["numpy", "scikit-learn", "pandas"], + True, + "Same requirements should produce the same hash.", + ), + ( + ["pandas==1.0"], + ["pandas==2.0"], + False, + "Different requirement versions should produce different hashes.", + ), + ], +) +def test_get_hash_with_package_requirements(reqs1, reqs2, should_be_equal, description): + """Tests how package requirements affect the final hash.""" + hash1 = _utils.get_hash(_function_add_one, package_requirements=reqs1) + hash2 = _utils.get_hash(_function_add_one, package_requirements=reqs2) + + if should_be_equal: + assert hash1 == hash2, f"FAILED: {description}" + else: + assert hash1 != hash2, f"FAILED: {description}" + + # Helper functions for signature inspection tests def 
_func_one_arg_annotated(x: int) -> int: """A function with one annotated arg and an annotated return type.""" From 2c72c56fb5893eb01d5aec6273d11945c9c532c5 Mon Sep 17 00:00:00 2001 From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com> Date: Wed, 27 Aug 2025 11:52:13 -0700 Subject: [PATCH 09/28] feat: add parameter shuffle for ml.model_selection.train_test_split (#2030) * feat: add parameter shuffle for ml.model_selection.train_test_split * mypy * rename --- bigframes/ml/model_selection.py | 26 ++++++- bigframes/ml/utils.py | 24 ++++++ tests/system/small/ml/test_model_selection.py | 74 +++++++++++++++++++ 3 files changed, 120 insertions(+), 4 deletions(-) diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index abb4b0f26c..ca089bb551 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -18,6 +18,7 @@ import inspect +from itertools import chain import time from typing import cast, Generator, List, Optional, Union @@ -36,12 +37,9 @@ def train_test_split( train_size: Union[float, None] = None, random_state: Union[int, None] = None, stratify: Union[bpd.Series, None] = None, + shuffle: bool = True, ) -> List[Union[bpd.DataFrame, bpd.Series]]: - # TODO(garrettwu): scikit-learn throws an error when the dataframes don't have the same - # number of rows. We probably want to do something similar. Now the implementation is based - # on index. We'll move to based on ordering first. - if test_size is None: if train_size is None: test_size = 0.25 @@ -61,6 +59,26 @@ def train_test_split( f"The sum of train_size and test_size exceeds 1.0. train_size: {train_size}. test_size: {test_size}" ) + if not shuffle: + if stratify is not None: + raise ValueError( + "Stratified train/test split is not implemented for shuffle=False" + ) + bf_arrays = list(utils.batch_convert_to_bf_equivalent(*arrays)) + + total_rows = len(bf_arrays[0]) + train_rows = int(total_rows * train_size) + test_rows = total_rows - train_rows + + return list( + chain.from_iterable( + [ + [bf_array.head(train_rows), bf_array.tail(test_rows)] + for bf_array in bf_arrays + ] + ) + ) + dfs = list(utils.batch_convert_to_dataframe(*arrays)) def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]: diff --git a/bigframes/ml/utils.py b/bigframes/ml/utils.py index 5c02789576..80630c4f81 100644 --- a/bigframes/ml/utils.py +++ b/bigframes/ml/utils.py @@ -79,6 +79,30 @@ def batch_convert_to_series( ) +def batch_convert_to_bf_equivalent( + *input: ArrayType, session: Optional[Session] = None +) -> Generator[Union[bpd.DataFrame, bpd.Series], None, None]: + """Converts the input to BigFrames DataFrame or Series. + + Args: + session: + The session to convert local pandas instances to BigFrames counter-parts. + It is not used if the input itself is already a BigFrame data frame or series. 
+ + """ + _validate_sessions(*input, session=session) + + for frame in input: + if isinstance(frame, bpd.DataFrame) or isinstance(frame, pd.DataFrame): + yield convert.to_bf_dataframe(frame, default_index=None, session=session) + elif isinstance(frame, bpd.Series) or isinstance(frame, pd.Series): + yield convert.to_bf_series( + _get_only_column(frame), default_index=None, session=session + ) + else: + raise ValueError(f"Unsupported type: {type(frame)}") + + def _validate_sessions(*input: ArrayType, session: Optional[Session]): session_ids = set( i._session.session_id diff --git a/tests/system/small/ml/test_model_selection.py b/tests/system/small/ml/test_model_selection.py index c1a1e073b9..ebce6e405a 100644 --- a/tests/system/small/ml/test_model_selection.py +++ b/tests/system/small/ml/test_model_selection.py @@ -13,12 +13,14 @@ # limitations under the License. import math +from typing import cast import pandas as pd import pytest from bigframes.ml import model_selection import bigframes.pandas as bpd +import bigframes.session @pytest.mark.parametrize( @@ -219,6 +221,78 @@ def test_train_test_split_seeded_correct_rows( ) +def test_train_test_split_no_shuffle_correct_shape( + penguins_df_default_index: bpd.DataFrame, +): + X = penguins_df_default_index[["species"]] + y = penguins_df_default_index["body_mass_g"] + X_train, X_test, y_train, y_test = model_selection.train_test_split( + X, y, shuffle=False + ) + assert isinstance(X_train, bpd.DataFrame) + assert isinstance(X_test, bpd.DataFrame) + assert isinstance(y_train, bpd.Series) + assert isinstance(y_test, bpd.Series) + + assert X_train.shape == (258, 1) + assert X_test.shape == (86, 1) + assert y_train.shape == (258,) + assert y_test.shape == (86,) + + +def test_train_test_split_no_shuffle_correct_rows( + session: bigframes.session.Session, penguins_pandas_df_default_index: bpd.DataFrame +): + # Note that we're using `penguins_pandas_df_default_index` as this test depends + # on a stable row order being present end to end + # filter down to the chunkiest penguins, to keep our test code a reasonable size + all_data = penguins_pandas_df_default_index[ + penguins_pandas_df_default_index.body_mass_g > 5500 + ].sort_index() + + # Note that bigframes loses the index if it doesn't have a name + all_data.index.name = "rowindex" + + df = session.read_pandas(all_data) + + X = df[ + [ + "species", + "island", + "culmen_length_mm", + ] + ] + y = df["body_mass_g"] + X_train, X_test, y_train, y_test = model_selection.train_test_split( + X, y, shuffle=False + ) + + X_train_pd = cast(bpd.DataFrame, X_train).to_pandas() + X_test_pd = cast(bpd.DataFrame, X_test).to_pandas() + y_train_pd = cast(bpd.Series, y_train).to_pandas() + y_test_pd = cast(bpd.Series, y_test).to_pandas() + + total_rows = len(all_data) + train_size = 0.75 + train_rows = int(total_rows * train_size) + test_rows = total_rows - train_rows + + expected_X_train = all_data.head(train_rows)[ + ["species", "island", "culmen_length_mm"] + ] + expected_y_train = all_data.head(train_rows)["body_mass_g"] + + expected_X_test = all_data.tail(test_rows)[ + ["species", "island", "culmen_length_mm"] + ] + expected_y_test = all_data.tail(test_rows)["body_mass_g"] + + pd.testing.assert_frame_equal(X_train_pd, expected_X_train) + pd.testing.assert_frame_equal(X_test_pd, expected_X_test) + pd.testing.assert_series_equal(y_train_pd, expected_y_train) + pd.testing.assert_series_equal(y_test_pd, expected_y_test) + + @pytest.mark.parametrize( ("train_size", "test_size"), [ From 
ba0d23b59c44ba5a46ace8182ad0e0cfc703b3ab Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 27 Aug 2025 12:30:23 -0700 Subject: [PATCH 10/28] feat: support multi-column assignment for DataFrame (#2028) * feat: support multi-column assignment for DataFrame * fix lint * fix mypy * fix Sequence type checking bug --- bigframes/dataframe.py | 41 +++++++++++-- tests/system/small/test_dataframe.py | 61 +++++++++++++++++++ .../bigframes_vendored/pandas/core/frame.py | 34 ++++++++++- 3 files changed, 131 insertions(+), 5 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 85b8245272..b2947f7493 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -26,6 +26,7 @@ import traceback import typing from typing import ( + Any, Callable, Dict, Hashable, @@ -91,6 +92,7 @@ import bigframes.session SingleItemValue = Union[bigframes.series.Series, int, float, str, Callable] + MultiItemValue = Union["DataFrame", Sequence[int | float | str | Callable]] LevelType = typing.Hashable LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]] @@ -884,8 +886,13 @@ def __delitem__(self, key: str): df = self.drop(columns=[key]) self._set_block(df._get_block()) - def __setitem__(self, key: str, value: SingleItemValue): - df = self._assign_single_item(key, value) + def __setitem__( + self, key: str | list[str], value: SingleItemValue | MultiItemValue + ): + if isinstance(key, list): + df = self._assign_multi_items(key, value) + else: + df = self._assign_single_item(key, value) self._set_block(df._get_block()) __setitem__.__doc__ = inspect.getdoc(vendored_pandas_frame.DataFrame.__setitem__) @@ -2212,7 +2219,7 @@ def assign(self, **kwargs) -> DataFrame: def _assign_single_item( self, k: str, - v: SingleItemValue, + v: SingleItemValue | MultiItemValue, ) -> DataFrame: if isinstance(v, bigframes.series.Series): return self._assign_series_join_on_index(k, v) @@ -2230,7 +2237,33 @@ def _assign_single_item( elif utils.is_list_like(v): return self._assign_single_item_listlike(k, v) else: - return self._assign_scalar(k, v) + return self._assign_scalar(k, v) # type: ignore + + def _assign_multi_items( + self, + k: list[str], + v: SingleItemValue | MultiItemValue, + ) -> DataFrame: + value_sources: Sequence[Any] = [] + if isinstance(v, DataFrame): + value_sources = [v[col] for col in v.columns] + elif isinstance(v, bigframes.series.Series): + # For behavior consistency with Pandas. + raise ValueError("Columns must be same length as key") + elif isinstance(v, Sequence): + value_sources = v + else: + # We assign the same scalar value to all target columns. + value_sources = [v] * len(k) + + if len(value_sources) != len(k): + raise ValueError("Columns must be same length as key") + + # Repeatedly assign columns in order. 
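+        # The source values are materialized from the right-hand side before any
+        # assignment happens, so a later target never observes an earlier one;
+        # e.g. df[["a", "b"]] = df[["b", "a"]] swaps the two columns, matching
+        # pandas semantics (see the docstring and unit tests in this change).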
+ result = self._assign_single_item(k[0], value_sources[0]) + for target, source in zip(k[1:], value_sources[1:]): + result = result._assign_single_item(target, source) + return result def _assign_single_item_listlike(self, k: str, v: Sequence) -> DataFrame: given_rows = len(v) diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 51f4674ba4..c7f9627531 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -1138,6 +1138,67 @@ def test_assign_new_column_w_setitem_list_error(scalars_dfs): bf_df["new_col"] = [1, 2, 3] +@pytest.mark.parametrize( + ("key", "value"), + [ + pytest.param(["int64_col", "int64_too"], 1, id="scalar_to_existing_column"), + pytest.param( + ["int64_col", "int64_too"], [1, 2], id="sequence_to_existing_column" + ), + pytest.param( + ["int64_col", "new_col"], [1, 2], id="sequence_to_partial_new_column" + ), + pytest.param( + ["new_col", "new_col_too"], [1, 2], id="sequence_to_full_new_column" + ), + ], +) +def test_setitem_multicolumn_with_literals(scalars_dfs, key, value): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.copy() + pd_result = scalars_pandas_df.copy() + + bf_result[key] = value + pd_result[key] = value + + pd.testing.assert_frame_equal(pd_result, bf_result.to_pandas(), check_dtype=False) + + +def test_setitem_multicolumn_with_literals_different_lengths_raise_error(scalars_dfs): + scalars_df, _ = scalars_dfs + bf_result = scalars_df.copy() + + with pytest.raises(ValueError): + bf_result[["int64_col", "int64_too"]] = [1] + + +def test_setitem_multicolumn_with_dataframes(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.copy() + pd_result = scalars_pandas_df.copy() + + bf_result[["int64_col", "int64_too"]] = bf_result[["int64_too", "int64_col"]] / 2 + pd_result[["int64_col", "int64_too"]] = pd_result[["int64_too", "int64_col"]] / 2 + + pd.testing.assert_frame_equal(pd_result, bf_result.to_pandas(), check_dtype=False) + + +def test_setitem_multicolumn_with_dataframes_series_on_rhs_raise_error(scalars_dfs): + scalars_df, _ = scalars_dfs + bf_result = scalars_df.copy() + + with pytest.raises(ValueError): + bf_result[["int64_col", "int64_too"]] = bf_result["int64_col"] / 2 + + +def test_setitem_multicolumn_with_dataframes_different_lengths_raise_error(scalars_dfs): + scalars_df, _ = scalars_dfs + bf_result = scalars_df.copy() + + with pytest.raises(ValueError): + bf_result[["int64_col"]] = bf_result[["int64_col", "int64_too"]] / 2 + + def test_assign_existing_column(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs kwargs = {"int64_col": 2} diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 44ca558070..953ece9beb 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -7626,11 +7626,43 @@ def __setitem__(self, key, value): [3 rows x 5 columns] + You can assign a scalar to multiple columns. 
+ + >>> df[["age", "new_age"]] = 25 + >>> df + name age location country new_age + 0 alpha 25 WA USA 25 + 1 beta 25 NY USA 25 + 2 gamma 25 CA USA 25 + + [3 rows x 5 columns] + + You can use a sequence of scalars for assignment of multiple columns: + + >>> df[["age", "is_happy"]] = [20, True] + >>> df + name age location country new_age is_happy + 0 alpha 20 WA USA 25 True + 1 beta 20 NY USA 25 True + 2 gamma 20 CA USA 25 True + + [3 rows x 6 columns] + + You can use a dataframe for assignment of multiple columns: + >>> df[["age", "new_age"]] = df[["new_age", "age"]] + >>> df + name age location country new_age is_happy + 0 alpha 25 WA USA 20 True + 1 beta 25 NY USA 20 True + 2 gamma 25 CA USA 20 True + + [3 rows x 6 columns] + Args: key (column index): It can be a new column to be inserted, or an existing column to be modified. - value (scalar or Series): + value (scalar, Sequence, DataFrame, or Series): Value to be assigned to the column """ raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE) From c0b54f03849ee3115413670e690e68f3ef10f2ec Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Wed, 27 Aug 2025 13:48:55 -0700 Subject: [PATCH 11/28] feat: Support string matching in local executor (#2032) --- bigframes/core/compile/polars/compiler.py | 28 ++++++++ bigframes/session/polars_executor.py | 12 +++- tests/system/small/engines/test_strings.py | 77 ++++++++++++++++++++++ 3 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 tests/system/small/engines/test_strings.py diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 1ba76dee5b..1bfbe0f734 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -301,6 +301,34 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: assert isinstance(op, string_ops.StrConcatOp) return pl.concat_str(l_input, r_input) + @compile_op.register(string_ops.StrContainsOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + assert isinstance(op, string_ops.StrContainsOp) + return input.str.contains(pattern=op.pat, literal=True) + + @compile_op.register(string_ops.StrContainsRegexOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + assert isinstance(op, string_ops.StrContainsRegexOp) + return input.str.contains(pattern=op.pat, literal=False) + + @compile_op.register(string_ops.StartsWithOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + assert isinstance(op, string_ops.StartsWithOp) + if len(op.pat) == 1: + return input.str.starts_with(op.pat[0]) + else: + return pl.any_horizontal( + *(input.str.starts_with(pat) for pat in op.pat) + ) + + @compile_op.register(string_ops.EndsWithOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + assert isinstance(op, string_ops.EndsWithOp) + if len(op.pat) == 1: + return input.str.ends_with(op.pat[0]) + else: + return pl.any_horizontal(*(input.str.ends_with(pat) for pat in op.pat)) + @compile_op.register(dt_ops.StrftimeOp) def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: assert isinstance(op, dt_ops.StrftimeOp) diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 6e3f0ca10f..b93d31d255 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -21,7 +21,13 @@ from bigframes.core import array_value, bigframe_node, expression, local_data, nodes import bigframes.operations from bigframes.operations import aggregations as agg_ops -from 
bigframes.operations import bool_ops, comparison_ops, generic_ops, numeric_ops +from bigframes.operations import ( + bool_ops, + comparison_ops, + generic_ops, + numeric_ops, + string_ops, +) from bigframes.session import executor, semi_executor if TYPE_CHECKING: @@ -69,6 +75,10 @@ generic_ops.IsInOp, generic_ops.IsNullOp, generic_ops.NotNullOp, + string_ops.StartsWithOp, + string_ops.EndsWithOp, + string_ops.StrContainsOp, + string_ops.StrContainsRegexOp, ) _COMPATIBLE_AGG_OPS = ( agg_ops.SizeOp, diff --git a/tests/system/small/engines/test_strings.py b/tests/system/small/engines/test_strings.py new file mode 100644 index 0000000000..cbab517ef0 --- /dev/null +++ b/tests/system/small/engines/test_strings.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes.core import array_value +import bigframes.operations as ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_str_contains(scalars_array_value: array_value.ArrayValue, engine): + arr, _ = scalars_array_value.compute_values( + [ + ops.StrContainsOp("(?i)hEllo").as_expr("string_col"), + ops.StrContainsOp("Hello").as_expr("string_col"), + ops.StrContainsOp("T").as_expr("string_col"), + ops.StrContainsOp(".*").as_expr("string_col"), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_str_contains_regex( + scalars_array_value: array_value.ArrayValue, engine +): + arr, _ = scalars_array_value.compute_values( + [ + ops.StrContainsRegexOp("(?i)hEllo").as_expr("string_col"), + ops.StrContainsRegexOp("Hello").as_expr("string_col"), + ops.StrContainsRegexOp("T").as_expr("string_col"), + ops.StrContainsRegexOp(".*").as_expr("string_col"), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_str_startswith(scalars_array_value: array_value.ArrayValue, engine): + arr, _ = scalars_array_value.compute_values( + [ + ops.StartsWithOp("He").as_expr("string_col"), + ops.StartsWithOp("llo").as_expr("string_col"), + ops.StartsWithOp(("He", "T", "ca")).as_expr("string_col"), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_str_endswith(scalars_array_value: array_value.ArrayValue, engine): + arr, _ = scalars_array_value.compute_values( + [ + ops.EndsWithOp("!").as_expr("string_col"), + ops.EndsWithOp("llo").as_expr("string_col"), + ops.EndsWithOp(("He", "T", "ca")).as_expr("string_col"), + ] + ) + 
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) From 935af107ef98837fb2b81d72185d0b6a9e09fbcf Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Wed, 27 Aug 2025 14:06:25 -0700 Subject: [PATCH 12/28] fix: Fix scalar op lowering tree walk (#2029) --- bigframes/core/expression.py | 5 +++++ bigframes/core/rewrite/op_lowering.py | 2 +- tests/system/small/engines/test_generic_ops.py | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 7b20e430ff..0e94193bd3 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -253,6 +253,11 @@ def is_identity(self) -> bool: def transform_children(self, t: Callable[[Expression], Expression]) -> Expression: ... + def bottom_up(self, t: Callable[[Expression], Expression]) -> Expression: + expr = self.transform_children(lambda child: child.bottom_up(t)) + expr = t(expr) + return expr + def walk(self) -> Generator[Expression, None, None]: yield self for child in self.children: diff --git a/bigframes/core/rewrite/op_lowering.py b/bigframes/core/rewrite/op_lowering.py index a64a4cc8c4..6473c3bf8a 100644 --- a/bigframes/core/rewrite/op_lowering.py +++ b/bigframes/core/rewrite/op_lowering.py @@ -44,7 +44,7 @@ def lower_expr_step(expr: expression.Expression) -> expression.Expression: return maybe_rule.lower(expr) return expr - return lower_expr_step(expr.transform_children(lower_expr_step)) + return expr.bottom_up(lower_expr_step) def lower_node(node: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode: if isinstance( diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index 1d28c335a6..8deef3638e 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -423,3 +423,18 @@ def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine): ) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_isin_op_nested_filter( + scalars_array_value: array_value.ArrayValue, engine +): + isin_clause = ops.IsInOp((1, 2, 3)).as_expr(expression.deref("int64_col")) + filter_clause = ops.invert_op.as_expr( + ops.or_op.as_expr( + expression.deref("bool_col"), ops.invert_op.as_expr(isin_clause) + ) + ) + arr = scalars_array_value.filter(filter_clause) + + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) From 5b8bdec771324fdb128c5fee7c4b376bef19d1a1 Mon Sep 17 00:00:00 2001 From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com> Date: Wed, 27 Aug 2025 18:04:20 -0700 Subject: [PATCH 13/28] chore: fix typos in bigframes.ml (#2035) * Fix: Fix typos in bigframes/ml * Fix: Fix mypy error in dataframe.py --- bigframes/dataframe.py | 1 + bigframes/ml/cluster.py | 2 +- bigframes/ml/forecasting.py | 2 +- bigframes/ml/llm.py | 2 +- bigframes/ml/model_selection.py | 2 +- bigframes/ml/preprocessing.py | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index b2947f7493..85760d94bc 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -582,6 +582,7 @@ def __getitem__( # Index of column labels can be treated the same as a sequence of column labels. 
pandas.Index, bigframes.series.Series, + slice, ], ): # No return type annotations (like pandas) as type cannot always be determined statically # NOTE: This implements the operations described in diff --git a/bigframes/ml/cluster.py b/bigframes/ml/cluster.py index cd27357680..9ce4649c5e 100644 --- a/bigframes/ml/cluster.py +++ b/bigframes/ml/cluster.py @@ -59,7 +59,7 @@ def __init__( warm_start: bool = False, ): self.n_clusters = n_clusters - # allow the alias to be compatible with sklean + # allow the alias to be compatible with sklearn self.init = "kmeans++" if init == "k-means++" else init self.init_col = init_col self.distance_type = distance_type diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 2e93e5485f..d26abdfa71 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -211,7 +211,7 @@ def _fit( Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series, or pandas.core.frame.DataFrame or pandas.core.series.Series): - A dataframe or series of trainging timestamp. + A dataframe or series of training timestamp. y (bigframes.dataframe.DataFrame, or bigframes.series.Series, or pandas.core.frame.DataFrame, or pandas.core.series.Series): Target values for training. diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 11861c786e..eba15909b4 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -834,7 +834,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator: class Claude3TextGenerator(base.RetriableRemotePredictor): """Claude3 text generator LLM model. - Go to Google Cloud Console -> Vertex AI -> Model Garden page to enabe the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models. + Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models. https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#grant-permissions .. note:: diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index ca089bb551..6eba4f81c2 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -82,7 +82,7 @@ def train_test_split( dfs = list(utils.batch_convert_to_dataframe(*arrays)) def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFrame]: - """Split a single DF accoding to the stratify Series.""" + """Split a single DF according to the stratify Series.""" stratify = stratify.rename("bigframes_stratify_col") # avoid name conflicts merged_df = df.join(stratify.to_frame(), how="outer") diff --git a/bigframes/ml/preprocessing.py b/bigframes/ml/preprocessing.py index 0448d8544a..2e8dc64a53 100644 --- a/bigframes/ml/preprocessing.py +++ b/bigframes/ml/preprocessing.py @@ -434,7 +434,7 @@ def _compile_to_sql( if columns is None: columns = X.columns drop = self.drop if self.drop is not None else "none" - # minus one here since BQML's inplimentation always includes index 0, and top_k is on top of that. + # minus one here since BQML's implementation always includes index 0, and top_k is on top of that. 
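+        # e.g. max_categories=5 translates to top_k=4: index 0 plus four encoded
+        # categories.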
top_k = ( (self.max_categories - 1) if self.max_categories is not None From 209d0d48956fafc3cf40cded2d8a2468eefd8813 Mon Sep 17 00:00:00 2001 From: Huan Chen <142538604+Genesis929@users.noreply.github.com> Date: Wed, 27 Aug 2025 18:09:22 -0700 Subject: [PATCH 14/28] chore: add bq execution time to benchmark (#2033) * chore: update benchmark metrics * fix metric calculation --- bigframes/session/metrics.py | 7 +- scripts/run_and_publish_benchmark.py | 134 +++++++++++---------------- testing/constraints-3.11.txt | 2 +- 3 files changed, 60 insertions(+), 83 deletions(-) diff --git a/bigframes/session/metrics.py b/bigframes/session/metrics.py index 8ec8d525cc..8d43a83d73 100644 --- a/bigframes/session/metrics.py +++ b/bigframes/session/metrics.py @@ -45,12 +45,17 @@ def count_job_stats( bytes_processed = getattr(row_iterator, "total_bytes_processed", 0) or 0 query_char_count = len(getattr(row_iterator, "query", "") or "") slot_millis = getattr(row_iterator, "slot_millis", 0) or 0 - exec_seconds = 0.0 + created = getattr(row_iterator, "created", None) + ended = getattr(row_iterator, "ended", None) + exec_seconds = ( + (ended - created).total_seconds() if created and ended else 0.0 + ) self.execution_count += 1 self.query_char_count += query_char_count self.bytes_processed += bytes_processed self.slot_millis += slot_millis + self.execution_secs += exec_seconds elif query_job.configuration.dry_run: query_char_count = len(query_job.query) diff --git a/scripts/run_and_publish_benchmark.py b/scripts/run_and_publish_benchmark.py index 248322f619..859d68e60e 100644 --- a/scripts/run_and_publish_benchmark.py +++ b/scripts/run_and_publish_benchmark.py @@ -84,43 +84,36 @@ def collect_benchmark_result( path = pathlib.Path(benchmark_path) try: results_dict: Dict[str, List[Union[int, float, None]]] = {} - bytes_files = sorted(path.rglob("*.bytesprocessed")) - millis_files = sorted(path.rglob("*.slotmillis")) - bq_seconds_files = sorted(path.rglob("*.bq_exec_time_seconds")) + # Use local_seconds_files as the baseline local_seconds_files = sorted(path.rglob("*.local_exec_time_seconds")) - query_char_count_files = sorted(path.rglob("*.query_char_count")) - error_files = sorted(path.rglob("*.error")) - - if not ( - len(millis_files) - == len(bq_seconds_files) - <= len(bytes_files) - == len(query_char_count_files) - == len(local_seconds_files) - ): - raise ValueError( - "Mismatch in the number of report files for bytes, millis, seconds and query char count: \n" - f"millis_files: {len(millis_files)}\n" - f"bq_seconds_files: {len(bq_seconds_files)}\n" - f"bytes_files: {len(bytes_files)}\n" - f"query_char_count_files: {len(query_char_count_files)}\n" - f"local_seconds_files: {len(local_seconds_files)}\n" - ) - - has_full_metrics = len(bq_seconds_files) == len(local_seconds_files) - - for idx in range(len(local_seconds_files)): - query_char_count_file = query_char_count_files[idx] - local_seconds_file = local_seconds_files[idx] - bytes_file = bytes_files[idx] - filename = query_char_count_file.relative_to(path).with_suffix("") - if filename != local_seconds_file.relative_to(path).with_suffix( - "" - ) or filename != bytes_file.relative_to(path).with_suffix(""): - raise ValueError( - "File name mismatch among query_char_count, bytes and seconds reports." 
- ) + benchmarks_with_missing_files = [] + + for local_seconds_file in local_seconds_files: + base_name = local_seconds_file.name.removesuffix(".local_exec_time_seconds") + base_path = local_seconds_file.parent / base_name + filename = base_path.relative_to(path) + + # Construct paths for other metric files + bytes_file = pathlib.Path(f"{base_path}.bytesprocessed") + millis_file = pathlib.Path(f"{base_path}.slotmillis") + bq_seconds_file = pathlib.Path(f"{base_path}.bq_exec_time_seconds") + query_char_count_file = pathlib.Path(f"{base_path}.query_char_count") + + # Check if all corresponding files exist + missing_files = [] + if not bytes_file.exists(): + missing_files.append(bytes_file.name) + if not millis_file.exists(): + missing_files.append(millis_file.name) + if not bq_seconds_file.exists(): + missing_files.append(bq_seconds_file.name) + if not query_char_count_file.exists(): + missing_files.append(query_char_count_file.name) + + if missing_files: + benchmarks_with_missing_files.append((str(filename), missing_files)) + continue with open(query_char_count_file, "r") as file: lines = file.read().splitlines() @@ -135,26 +128,13 @@ def collect_benchmark_result( lines = file.read().splitlines() total_bytes = sum(int(line) for line in lines) / iterations - if not has_full_metrics: - total_slot_millis = None - bq_seconds = None - else: - millis_file = millis_files[idx] - bq_seconds_file = bq_seconds_files[idx] - if filename != millis_file.relative_to(path).with_suffix( - "" - ) or filename != bq_seconds_file.relative_to(path).with_suffix(""): - raise ValueError( - "File name mismatch among query_char_count, bytes, millis, and seconds reports." - ) - - with open(millis_file, "r") as file: - lines = file.read().splitlines() - total_slot_millis = sum(int(line) for line in lines) / iterations + with open(millis_file, "r") as file: + lines = file.read().splitlines() + total_slot_millis = sum(int(line) for line in lines) / iterations - with open(bq_seconds_file, "r") as file: - lines = file.read().splitlines() - bq_seconds = sum(float(line) for line in lines) / iterations + with open(bq_seconds_file, "r") as file: + lines = file.read().splitlines() + bq_seconds = sum(float(line) for line in lines) / iterations results_dict[str(filename)] = [ query_count, @@ -207,13 +187,9 @@ def collect_benchmark_result( f"{index} - query count: {row['Query_Count']}," + f" query char count: {row['Query_Char_Count']}," + f" bytes processed sum: {row['Bytes_Processed']}," - + (f" slot millis sum: {row['Slot_Millis']}," if has_full_metrics else "") - + f" local execution time: {formatted_local_exec_time} seconds" - + ( - f", bigquery execution time: {round(row['BigQuery_Execution_Time_Sec'], 1)} seconds" - if has_full_metrics - else "" - ) + + f" slot millis sum: {row['Slot_Millis']}," + + f" local execution time: {formatted_local_exec_time}" + + f", bigquery execution time: {round(row['BigQuery_Execution_Time_Sec'], 1)} seconds" ) geometric_mean_queries = geometric_mean_excluding_zeros( @@ -239,30 +215,26 @@ def collect_benchmark_result( f"---Geometric mean of queries: {geometric_mean_queries}," + f" Geometric mean of queries char counts: {geometric_mean_query_char_count}," + f" Geometric mean of bytes processed: {geometric_mean_bytes}," - + ( - f" Geometric mean of slot millis: {geometric_mean_slot_millis}," - if has_full_metrics - else "" - ) + + f" Geometric mean of slot millis: {geometric_mean_slot_millis}," + f" Geometric mean of local execution time: {geometric_mean_local_seconds} seconds" - + ( - f", 
Geometric mean of BigQuery execution time: {geometric_mean_bq_seconds} seconds---" - if has_full_metrics - else "" - ) + + f", Geometric mean of BigQuery execution time: {geometric_mean_bq_seconds} seconds---" ) - error_message = ( - "\n" - + "\n".join( - [ - f"Failed: {error_file.relative_to(path).with_suffix('')}" - for error_file in error_files - ] + all_errors: List[str] = [] + if error_files: + all_errors.extend( + f"Failed: {error_file.relative_to(path).with_suffix('')}" + for error_file in error_files ) - if error_files - else None - ) + if ( + benchmarks_with_missing_files + and os.getenv("BENCHMARK_AND_PUBLISH", "false") == "true" + ): + all_errors.extend( + f"Missing files for benchmark '{name}': {files}" + for name, files in benchmarks_with_missing_files + ) + error_message = "\n" + "\n".join(all_errors) if all_errors else None return ( benchmark_metrics.reset_index().rename(columns={"index": "Benchmark_Name"}), error_message, diff --git a/testing/constraints-3.11.txt b/testing/constraints-3.11.txt index 8fd20d453b..8c274bd9fb 100644 --- a/testing/constraints-3.11.txt +++ b/testing/constraints-3.11.txt @@ -152,7 +152,7 @@ google-auth==2.38.0 google-auth-httplib2==0.2.0 google-auth-oauthlib==1.2.2 google-cloud-aiplatform==1.106.0 -google-cloud-bigquery==3.35.1 +google-cloud-bigquery==3.36.0 google-cloud-bigquery-connection==1.18.3 google-cloud-bigquery-storage==2.32.0 google-cloud-core==2.4.3 From b0d620bbe8227189bbdc2ba5a913b03c70575296 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 28 Aug 2025 13:34:16 -0700 Subject: [PATCH 15/28] fix: read_csv fails when check file size for wildcard gcs files (#2019) --- bigframes/session/__init__.py | 22 ++++++++++++++++++---- tests/system/small/test_session.py | 26 ++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 10a112c779..66b0196286 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -18,6 +18,8 @@ from collections import abc import datetime +import fnmatch +import inspect import logging import os import secrets @@ -1344,12 +1346,24 @@ def read_json( def _check_file_size(self, filepath: str): max_size = 1024 * 1024 * 1024 # 1 GB in bytes if filepath.startswith("gs://"): # GCS file path + bucket_name, blob_path = filepath.split("/", 3)[2:] + client = storage.Client() - bucket_name, blob_name = filepath.split("/", 3)[2:] bucket = client.bucket(bucket_name) - blob = bucket.blob(blob_name) - blob.reload() - file_size = blob.size + + list_blobs_params = inspect.signature(bucket.list_blobs).parameters + if "match_glob" in list_blobs_params: + # Modern, efficient method for new library versions + matching_blobs = bucket.list_blobs(match_glob=blob_path) + file_size = sum(blob.size for blob in matching_blobs) + else: + # Fallback method for older library versions + prefix = blob_path.split("*", 1)[0] + all_blobs = bucket.list_blobs(prefix=prefix) + matching_blobs = [ + blob for blob in all_blobs if fnmatch.fnmatch(blob.name, blob_path) + ] + file_size = sum(blob.size for blob in matching_blobs) elif os.path.exists(filepath): # local file path file_size = os.path.getsize(filepath) else: diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index a04da64af0..f0a6302c7b 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -1287,6 +1287,32 @@ def test_read_csv_raises_error_for_invalid_index_col( session.read_csv(path, engine="bigquery", 
index_col=index_col) +def test_read_csv_for_gcs_wildcard_path(session, df_and_gcs_csv): + scalars_pandas_df, path = df_and_gcs_csv + path = path.replace(".csv", "*.csv") + + index_col = "rowindex" + bf_df = session.read_csv(path, engine="bigquery", index_col=index_col) + + # Convert default pandas dtypes to match BigQuery DataFrames dtypes. + # Also, `expand=True` is needed to read from wildcard paths. See details: + # https://github.com/fsspec/gcsfs/issues/616, + if not pd.__version__.startswith("1."): + storage_options = {"expand": True} + else: + storage_options = None + pd_df = session.read_csv( + path, + index_col=index_col, + dtype=scalars_pandas_df.dtypes.to_dict(), + storage_options=storage_options, + ) + + assert bf_df.shape == pd_df.shape + assert bf_df.columns.tolist() == pd_df.columns.tolist() + pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df.to_pandas()) + + def test_read_csv_for_names(session, df_and_gcs_csv_for_two_columns): _, path = df_and_gcs_csv_for_two_columns From 3c87e9725b4dd19f22c77a85aca3215b425e5526 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 28 Aug 2025 13:38:11 -0700 Subject: [PATCH 16/28] test: add unit test for remap_variables (#2037) --- .../{test_rewrite.py => rewrite/conftest.py} | 36 ++--- tests/unit/core/rewrite/test_identifiers.py | 132 ++++++++++++++++++ tests/unit/core/rewrite/test_slices.py | 34 +++++ 3 files changed, 181 insertions(+), 21 deletions(-) rename tests/unit/core/{test_rewrite.py => rewrite/conftest.py} (56%) create mode 100644 tests/unit/core/rewrite/test_identifiers.py create mode 100644 tests/unit/core/rewrite/test_slices.py diff --git a/tests/unit/core/test_rewrite.py b/tests/unit/core/rewrite/conftest.py similarity index 56% rename from tests/unit/core/test_rewrite.py rename to tests/unit/core/rewrite/conftest.py index 1f1a2c3db9..22b897f3bf 100644 --- a/tests/unit/core/test_rewrite.py +++ b/tests/unit/core/rewrite/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,10 +14,9 @@ import unittest.mock as mock import google.cloud.bigquery +import pytest import bigframes.core as core -import bigframes.core.nodes as nodes -import bigframes.core.rewrite.slices import bigframes.core.schema TABLE_REF = google.cloud.bigquery.TableReference.from_string("project.dataset.table") @@ -31,27 +30,22 @@ ) FAKE_SESSION = mock.create_autospec(bigframes.Session, instance=True) type(FAKE_SESSION)._strictly_ordered = mock.PropertyMock(return_value=True) -LEAF = core.ArrayValue.from_table( - session=FAKE_SESSION, - table=TABLE, - schema=bigframes.core.schema.ArraySchema.from_bq_table(TABLE), -).node -def test_rewrite_noop_slice(): - slice = nodes.SliceNode(LEAF, None, None) - result = bigframes.core.rewrite.slices.rewrite_slice(slice) - assert result == LEAF +@pytest.fixture +def table(): + return TABLE -def test_rewrite_reverse_slice(): - slice = nodes.SliceNode(LEAF, None, None, -1) - result = bigframes.core.rewrite.slices.rewrite_slice(slice) - assert result == nodes.ReversedNode(LEAF) +@pytest.fixture +def fake_session(): + return FAKE_SESSION -def test_rewrite_filter_slice(): - slice = nodes.SliceNode(LEAF, None, 2) - result = bigframes.core.rewrite.slices.rewrite_slice(slice) - assert list(result.fields) == list(LEAF.fields) - assert isinstance(result.child, nodes.FilterNode) +@pytest.fixture +def leaf(fake_session, table): + return core.ArrayValue.from_table( + session=fake_session, + table=table, + schema=bigframes.core.schema.ArraySchema.from_bq_table(table), + ).node diff --git a/tests/unit/core/rewrite/test_identifiers.py b/tests/unit/core/rewrite/test_identifiers.py new file mode 100644 index 0000000000..fd12df60a8 --- /dev/null +++ b/tests/unit/core/rewrite/test_identifiers.py @@ -0,0 +1,132 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import bigframes.core as core +import bigframes.core.identifiers as identifiers +import bigframes.core.nodes as nodes +import bigframes.core.rewrite.identifiers as id_rewrite + + +def test_remap_variables_single_node(leaf): + node = leaf + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node, mapping = id_rewrite.remap_variables(node, id_generator) + assert new_node is not node + assert len(mapping) == 2 + assert set(mapping.keys()) == {f.id for f in node.fields} + assert set(mapping.values()) == { + identifiers.ColumnId("id_0"), + identifiers.ColumnId("id_1"), + } + + +def test_remap_variables_projection(leaf): + node = nodes.ProjectionNode( + leaf, + ( + ( + core.expression.DerefOp(leaf.fields[0].id), + identifiers.ColumnId("new_col"), + ), + ), + ) + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node, mapping = id_rewrite.remap_variables(node, id_generator) + assert new_node is not node + assert len(mapping) == 3 + assert set(mapping.keys()) == {f.id for f in node.fields} + assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)} + + +def test_remap_variables_nested_join_stability(leaf, fake_session, table): + # Create two more distinct leaf nodes + leaf2_uncached = core.ArrayValue.from_table( + session=fake_session, + table=table, + schema=leaf.schema, + ).node + leaf2 = leaf2_uncached.remap_vars( + { + field.id: identifiers.ColumnId(f"leaf2_{field.id.name}") + for field in leaf2_uncached.fields + } + ) + leaf3_uncached = core.ArrayValue.from_table( + session=fake_session, + table=table, + schema=leaf.schema, + ).node + leaf3 = leaf3_uncached.remap_vars( + { + field.id: identifiers.ColumnId(f"leaf3_{field.id.name}") + for field in leaf3_uncached.fields + } + ) + + # Create a nested join: (leaf JOIN leaf2) JOIN leaf3 + inner_join = nodes.JoinNode( + left_child=leaf, + right_child=leaf2, + conditions=( + ( + core.expression.DerefOp(leaf.fields[0].id), + core.expression.DerefOp(leaf2.fields[0].id), + ), + ), + type="inner", + propogate_order=False, + ) + outer_join = nodes.JoinNode( + left_child=inner_join, + right_child=leaf3, + conditions=( + ( + core.expression.DerefOp(inner_join.fields[0].id), + core.expression.DerefOp(leaf3.fields[0].id), + ), + ), + type="inner", + propogate_order=False, + ) + + # Run remap_variables twice and assert stability + id_generator1 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node1, mapping1 = id_rewrite.remap_variables(outer_join, id_generator1) + + id_generator2 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node2, mapping2 = id_rewrite.remap_variables(outer_join, id_generator2) + + assert new_node1 == new_node2 + assert mapping1 == mapping2 + + +def test_remap_variables_concat_self_stability(leaf): + # Create a concat node with the same child twice + node = nodes.ConcatNode( + children=(leaf, leaf), + output_ids=( + identifiers.ColumnId("concat_a"), + identifiers.ColumnId("concat_b"), + ), + ) + + # Run remap_variables twice and assert stability + id_generator1 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node1, mapping1 = id_rewrite.remap_variables(node, id_generator1) + + id_generator2 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node2, mapping2 = id_rewrite.remap_variables(node, id_generator2) + + assert new_node1 == new_node2 + assert mapping1 == mapping2 diff --git a/tests/unit/core/rewrite/test_slices.py b/tests/unit/core/rewrite/test_slices.py new file mode 100644 index 0000000000..6d49ffb80a --- 
/dev/null +++ b/tests/unit/core/rewrite/test_slices.py @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import bigframes.core.nodes as nodes +import bigframes.core.rewrite.slices + + +def test_rewrite_noop_slice(leaf): + slice = nodes.SliceNode(leaf, None, None) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) + assert result == leaf + + +def test_rewrite_reverse_slice(leaf): + slice = nodes.SliceNode(leaf, None, None, -1) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) + assert result == nodes.ReversedNode(leaf) + + +def test_rewrite_filter_slice(leaf): + slice = nodes.SliceNode(leaf, None, 2) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) + assert list(result.fields) == list(leaf.fields) + assert isinstance(result.child, nodes.FilterNode) From 39616374bba424996ebeb9a12096bfaf22660b44 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 28 Aug 2025 14:10:14 -0700 Subject: [PATCH 17/28] perf: improve iter_nodes_topo performance using Kahn's algorithm (#2038) --- bigframes/core/bigframe_node.py | 52 +++++++++++++-------------------- 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/bigframes/core/bigframe_node.py b/bigframes/core/bigframe_node.py index 9054ab9ba0..0c6f56f35a 100644 --- a/bigframes/core/bigframe_node.py +++ b/bigframes/core/bigframe_node.py @@ -20,17 +20,7 @@ import functools import itertools import typing -from typing import ( - Callable, - Dict, - Generator, - Iterable, - Mapping, - Sequence, - Set, - Tuple, - Union, -) +from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union from bigframes.core import expression, field, identifiers import bigframes.core.schema as schemata @@ -309,33 +299,31 @@ def unique_nodes( seen.add(item) stack.extend(item.child_nodes) - def edges( + def iter_nodes_topo( self: BigFrameNode, - ) -> Generator[Tuple[BigFrameNode, BigFrameNode], None, None]: - for item in self.unique_nodes(): - for child in item.child_nodes: - yield (item, child) - - def iter_nodes_topo(self: BigFrameNode) -> Generator[BigFrameNode, None, None]: - """Returns nodes from bottom up.""" - queue = collections.deque( - [node for node in self.unique_nodes() if not node.child_nodes] - ) - + ) -> Generator[BigFrameNode, None, None]: + """Returns nodes in reverse topological order, using Kahn's algorithm.""" child_to_parents: Dict[ - BigFrameNode, Set[BigFrameNode] - ] = collections.defaultdict(set) - for parent, child in self.edges(): - child_to_parents[child].add(parent) - - yielded = set() + BigFrameNode, list[BigFrameNode] + ] = collections.defaultdict(list) + out_degree: Dict[BigFrameNode, int] = collections.defaultdict(int) + + queue: collections.deque["BigFrameNode"] = collections.deque() + for node in list(self.unique_nodes()): + num_children = len(node.child_nodes) + out_degree[node] = num_children + if num_children == 0: + queue.append(node) + for child in node.child_nodes: + child_to_parents[child].append(node) while queue: item = 
queue.popleft() yield item - yielded.add(item) - for parent in child_to_parents[item]: - if set(parent.child_nodes).issubset(yielded): + parents = child_to_parents.get(item, []) + for parent in parents: + out_degree[parent] -= 1 + if out_degree[parent] == 0: queue.append(parent) def top_down( From fbb209468297a8057d9d49c40e425c3bfdeb92bd Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Thu, 28 Aug 2025 15:24:57 -0700 Subject: [PATCH 18/28] perf: Improve axis=1 aggregation performance (#2036) --- bigframes/core/blocks.py | 44 ++------------ .../ibis_compiler/aggregate_compiler.py | 2 +- .../ibis_compiler/scalar_op_registry.py | 22 +++++++ bigframes/core/compile/polars/compiler.py | 31 ++++++++++ bigframes/operations/__init__.py | 10 +++- bigframes/operations/array_ops.py | 27 ++++++++- tests/system/small/engines/conftest.py | 7 +++ tests/system/small/engines/test_array_ops.py | 60 +++++++++++++++++++ .../sql/compilers/bigquery/__init__.py | 3 + .../ibis/expr/operations/arrays.py | 15 +++++ .../bigframes_vendored/ibis/expr/rewrites.py | 2 +- .../ibis/expr/types/arrays.py | 18 ++++++ .../ibis/expr/types/logical.py | 3 + 13 files changed, 200 insertions(+), 44 deletions(-) create mode 100644 tests/system/small/engines/test_array_ops.py diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 1a2544704c..283f56fd39 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -1232,46 +1232,10 @@ def aggregate_all_and_stack( index_labels=[None], ).transpose(original_row_index=pd.Index([None]), single_row_mode=True) else: # axis_n == 1 - # using offsets as identity to group on. - # TODO: Allow to promote identity/total_order columns instead for better perf - expr_with_offsets, offset_col = self.expr.promote_offsets() - stacked_expr, (_, value_col_ids, passthrough_cols,) = unpivot( - expr_with_offsets, - row_labels=self.column_labels, - unpivot_columns=[tuple(self.value_columns)], - passthrough_columns=[*self.index_columns, offset_col], - ) - # these corresponed to passthrough_columns provided to unpivot - index_cols = passthrough_cols[:-1] - og_offset_col = passthrough_cols[-1] - index_aggregations = [ - ( - ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id)), - col_id, - ) - for col_id in index_cols - ] - # TODO: may need add NullaryAggregation in main_aggregation - # when agg add support for axis=1, needed for agg("size", axis=1) - assert isinstance( - operation, agg_ops.UnaryAggregateOp - ), f"Expected a unary operation, but got {operation}. Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)." 
- main_aggregation = ( - ex.UnaryAggregation(operation, ex.deref(value_col_ids[0])), - value_col_ids[0], - ) - # Drop row identity after aggregating over it - result_expr = stacked_expr.aggregate( - [*index_aggregations, main_aggregation], - by_column_ids=[og_offset_col], - dropna=dropna, - ).drop_columns([og_offset_col]) - return Block( - result_expr, - index_columns=index_cols, - column_labels=[None], - index_labels=self.index.names, - ) + as_array = ops.ToArrayOp().as_expr(*(col for col in self.value_columns)) + reduced = ops.ArrayReduceOp(operation).as_expr(as_array) + block, id = self.project_expr(reduced, None) + return block.select_column(id) def aggregate_size( self, diff --git a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py index 4e0bf477fc..291db44524 100644 --- a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py +++ b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py @@ -165,7 +165,7 @@ def _( ) -> ibis_types.NumericValue: # Will be null if all inputs are null. Pandas defaults to zero sum though. bq_sum = _apply_window_if_present(column.sum(), window) - return bq_sum.fill_null(ibis_types.literal(0)) + return bq_sum.coalesce(ibis_types.literal(0)) @compile_unary_agg.register diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index f3653efc56..969ae2659d 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1201,6 +1201,28 @@ def array_slice_op_impl(x: ibis_types.Value, op: ops.ArraySliceOp): return res +@scalar_op_compiler.register_nary_op(ops.ToArrayOp, pass_op=False) +def to_arry_op_impl(*values: ibis_types.Value): + do_upcast_bool = any(t.type().is_numeric() for t in values) + if do_upcast_bool: + values = tuple( + val.cast(ibis_dtypes.int64) if val.type().is_boolean() else val + for val in values + ) + return ibis_api.array(values) + + +@scalar_op_compiler.register_unary_op(ops.ArrayReduceOp, pass_op=True) +def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp): + import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compilers + + return typing.cast(ibis_types.ArrayValue, x).reduce( + lambda arr_vals: agg_compilers.compile_unary_agg( + op.aggregation, typing.cast(ibis_types.Column, arr_vals) + ) + ) + + # JSON Ops @scalar_op_compiler.register_binary_op(ops.JSONSet, pass_op=True) def json_set_op_impl(x: ibis_types.Value, y: ibis_types.Value, op: ops.JSONSet): diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 1bfbe0f734..3316154de7 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -31,6 +31,7 @@ import bigframes.dtypes import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops +import bigframes.operations.array_ops as arr_ops import bigframes.operations.bool_ops as bool_ops import bigframes.operations.comparison_ops as comp_ops import bigframes.operations.datetime_ops as dt_ops @@ -353,6 +354,36 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: assert isinstance(op, json_ops.JSONDecode) return input.str.json_decode(_DTYPE_MAPPING[op.to_type]) + @compile_op.register(arr_ops.ToArrayOp) + def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr: + return pl.concat_list(*inputs) + + @compile_op.register(arr_ops.ArrayReduceOp) + def _(self, op: 
ops.ArrayReduceOp, input: pl.Expr) -> pl.Expr: + # TODO: Unify this with general aggregation compilation? + if isinstance(op.aggregation, agg_ops.MinOp): + return input.list.min() + if isinstance(op.aggregation, agg_ops.MaxOp): + return input.list.max() + if isinstance(op.aggregation, agg_ops.SumOp): + return input.list.sum() + if isinstance(op.aggregation, agg_ops.MeanOp): + return input.list.mean() + if isinstance(op.aggregation, agg_ops.CountOp): + return input.list.len() + if isinstance(op.aggregation, agg_ops.StdOp): + return input.list.std() + if isinstance(op.aggregation, agg_ops.VarOp): + return input.list.var() + if isinstance(op.aggregation, agg_ops.AnyOp): + return input.list.any() + if isinstance(op.aggregation, agg_ops.AllOp): + return input.list.all() + else: + raise NotImplementedError( + f"Haven't implemented array aggregation: {op.aggregation}" + ) + @dataclasses.dataclass(frozen=True) class PolarsAggregateCompiler: scalar_compiler = PolarsExpressionCompiler() diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index e10a972790..e5888ace00 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -14,7 +14,13 @@ from __future__ import annotations -from bigframes.operations.array_ops import ArrayIndexOp, ArraySliceOp, ArrayToStringOp +from bigframes.operations.array_ops import ( + ArrayIndexOp, + ArrayReduceOp, + ArraySliceOp, + ArrayToStringOp, + ToArrayOp, +) from bigframes.operations.base_ops import ( BinaryOp, NaryOp, @@ -405,4 +411,6 @@ # Numpy ops mapping "NUMPY_TO_BINOP", "NUMPY_TO_OP", + "ToArrayOp", + "ArrayReduceOp", ] diff --git a/bigframes/operations/array_ops.py b/bigframes/operations/array_ops.py index c1e644fc11..61ada59cc7 100644 --- a/bigframes/operations/array_ops.py +++ b/bigframes/operations/array_ops.py @@ -13,10 +13,11 @@ # limitations under the License. import dataclasses +import functools import typing from bigframes import dtypes -from bigframes.operations import base_ops +from bigframes.operations import aggregations, base_ops @dataclasses.dataclass(frozen=True) @@ -63,3 +64,27 @@ def output_type(self, *input_types): return input_type else: raise TypeError("Input type must be an array or string-like type.") + + +class ToArrayOp(base_ops.NaryOp): + name: typing.ClassVar[str] = "array" + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + # very permissive, maybe should force caller to do this? 
+        common_type = functools.reduce(
+            lambda t1, t2: dtypes.coerce_to_common(t1, t2),
+            input_types,
+        )
+        return dtypes.list_type(common_type)
+
+
+@dataclasses.dataclass(frozen=True)
+class ArrayReduceOp(base_ops.UnaryOp):
+    name: typing.ClassVar[str] = "array_reduce"
+    aggregation: aggregations.AggregateOp
+
+    def output_type(self, *input_types):
+        input_type = input_types[0]
+        assert dtypes.is_array_like(input_type)
+        inner_type = dtypes.get_array_inner_type(input_type)
+        return self.aggregation.output_type(inner_type)
diff --git a/tests/system/small/engines/conftest.py b/tests/system/small/engines/conftest.py
index 4f0f875b34..9699cc6a61 100644
--- a/tests/system/small/engines/conftest.py
+++ b/tests/system/small/engines/conftest.py
@@ -90,3 +90,10 @@ def repeated_data_source(
     repeated_pandas_df: pd.DataFrame,
 ) -> local_data.ManagedArrowTable:
     return local_data.ManagedArrowTable.from_pandas(repeated_pandas_df)
+
+
+@pytest.fixture(scope="module")
+def arrays_array_value(
+    repeated_data_source: local_data.ManagedArrowTable, fake_session: bigframes.Session
+):
+    return ArrayValue.from_managed(repeated_data_source, fake_session)
diff --git a/tests/system/small/engines/test_array_ops.py b/tests/system/small/engines/test_array_ops.py
new file mode 100644
index 0000000000..c53b9e9dc1
--- /dev/null
+++ b/tests/system/small/engines/test_array_ops.py
@@ -0,0 +1,60 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from bigframes.core import array_value, expression
+import bigframes.operations as ops
+import bigframes.operations.aggregations as agg_ops
+from bigframes.session import polars_executor
+from bigframes.testing.engine_utils import assert_equivalence_execution
+
+pytest.importorskip("polars")
+
+# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree.
+REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_to_array_op(scalars_array_value: array_value.ArrayValue, engine): + # Bigquery won't allow you to materialize arrays with null, so use non-nullable + int64_non_null = ops.coalesce_op.as_expr("int64_col", expression.const(0)) + bool_col_non_null = ops.coalesce_op.as_expr("bool_col", expression.const(False)) + float_col_non_null = ops.coalesce_op.as_expr("float64_col", expression.const(0.0)) + string_col_non_null = ops.coalesce_op.as_expr("string_col", expression.const("")) + + arr, _ = scalars_array_value.compute_values( + [ + ops.ToArrayOp().as_expr(int64_non_null), + ops.ToArrayOp().as_expr( + int64_non_null, bool_col_non_null, float_col_non_null + ), + ops.ToArrayOp().as_expr(string_col_non_null, string_col_non_null), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_array_reduce_op(arrays_array_value: array_value.ArrayValue, engine): + arr, _ = arrays_array_value.compute_values( + [ + ops.ArrayReduceOp(agg_ops.SumOp()).as_expr("float_list_col"), + ops.ArrayReduceOp(agg_ops.StdOp()).as_expr("float_list_col"), + ops.ArrayReduceOp(agg_ops.MaxOp()).as_expr("date_list_col"), + ops.ArrayReduceOp(agg_ops.CountOp()).as_expr("string_list_col"), + ops.ArrayReduceOp(agg_ops.AnyOp()).as_expr("bool_list_col"), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 08bf0d7650..61bafeeca2 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -699,6 +699,9 @@ def visit_ArrayFilter(self, op, *, arg, body, param): def visit_ArrayMap(self, op, *, arg, body, param): return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param))) + def visit_ArrayReduce(self, op, *, arg, body, param): + return sg.select(body).from_(self._unnest(arg, as_=param)).subquery() + def visit_ArrayZip(self, op, *, arg): lengths = [self.f.array_length(arr) - 1 for arr in arg] idx = sg.to_identifier(util.gen_name("bq_arr_idx")) diff --git a/third_party/bigframes_vendored/ibis/expr/operations/arrays.py b/third_party/bigframes_vendored/ibis/expr/operations/arrays.py index 638b24a212..8134506255 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/arrays.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/arrays.py @@ -105,6 +105,21 @@ def dtype(self) -> dt.DataType: return dt.Array(self.body.dtype) +@public +class ArrayReduce(Value): + """Apply a function to every element of an array.""" + + arg: Value[dt.Array] + body: Value + param: str + + shape = rlz.shape_like("arg") + + @attribute + def dtype(self) -> dt.DataType: + return self.body.dtype + + @public class ArrayFilter(Value): """Filter array elements with a function.""" diff --git a/third_party/bigframes_vendored/ibis/expr/rewrites.py b/third_party/bigframes_vendored/ibis/expr/rewrites.py index a85498b30b..b0569846da 100644 --- a/third_party/bigframes_vendored/ibis/expr/rewrites.py +++ b/third_party/bigframes_vendored/ibis/expr/rewrites.py @@ -252,7 +252,7 @@ def rewrite_project_input(value, relation): # relation return value.replace( project_wrap_analytic | 
project_wrap_reduction, - filter=p.Value & ~p.WindowFunction, + filter=p.Value & ~p.WindowFunction & ~p.ArrayReduce, context={"rel": relation}, ) diff --git a/third_party/bigframes_vendored/ibis/expr/types/arrays.py b/third_party/bigframes_vendored/ibis/expr/types/arrays.py index a8f64490c1..72f01334c1 100644 --- a/third_party/bigframes_vendored/ibis/expr/types/arrays.py +++ b/third_party/bigframes_vendored/ibis/expr/types/arrays.py @@ -486,6 +486,24 @@ def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: body = resolve(parameter.to_expr()) return ops.ArrayMap(self, param=parameter.param, body=body).to_expr() + def reduce(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: + if isinstance(func, Deferred): + name = "_" + resolve = func.resolve + elif callable(func): + name = next(iter(inspect.signature(func).parameters.keys())) + resolve = func + else: + raise TypeError( + f"`func` must be a Deferred or Callable, got `{type(func).__name__}`" + ) + + parameter = ops.Argument( + name=name, shape=self.op().shape, dtype=self.type().value_type + ) + body = resolve(parameter.to_expr()) + return ops.ArrayReduce(self, param=parameter.param, body=body).to_expr() + def filter( self, predicate: Deferred | Callable[[ir.Value], bool | ir.BooleanValue] ) -> ir.ArrayValue: diff --git a/third_party/bigframes_vendored/ibis/expr/types/logical.py b/third_party/bigframes_vendored/ibis/expr/types/logical.py index 80a8527a04..cc86c747f6 100644 --- a/third_party/bigframes_vendored/ibis/expr/types/logical.py +++ b/third_party/bigframes_vendored/ibis/expr/types/logical.py @@ -353,6 +353,9 @@ def resolve_exists_subquery(outer): return Deferred(Call(resolve_exists_subquery, _)) elif len(parents) == 1: op = ops.Any(self, where=self._bind_to_parent_table(where)) + elif len(parents) == 0: + # array reduction case + op = ops.Any(self, where=self._bind_to_parent_table(where)) else: raise NotImplementedError( f'Cannot compute "any" for expression of type {type(self)} ' From 7ac6fe16f7f2c09d2efac6ab813ec841c21baef8 Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Fri, 29 Aug 2025 12:00:45 -0700 Subject: [PATCH 19/28] feat: Local date accessor execution support (#2034) --- bigframes/core/compile/polars/compiler.py | 53 +++++++++++++++ bigframes/operations/datetimes.py | 5 +- bigframes/operations/frequency_ops.py | 15 ++++- bigframes/session/polars_executor.py | 11 ++++ .../system/small/engines/test_temporal_ops.py | 66 +++++++++++++++++++ .../test_unary_compiler/test_floor_dt/out.sql | 2 +- .../expressions/test_unary_compiler.py | 2 +- 7 files changed, 150 insertions(+), 4 deletions(-) create mode 100644 tests/system/small/engines/test_temporal_ops.py diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 3316154de7..70fa516e51 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -34,7 +34,9 @@ import bigframes.operations.array_ops as arr_ops import bigframes.operations.bool_ops as bool_ops import bigframes.operations.comparison_ops as comp_ops +import bigframes.operations.date_ops as date_ops import bigframes.operations.datetime_ops as dt_ops +import bigframes.operations.frequency_ops as freq_ops import bigframes.operations.generic_ops as gen_ops import bigframes.operations.json_ops as json_ops import bigframes.operations.numeric_ops as num_ops @@ -75,6 +77,20 @@ def decorator(func): if polars_installed: + _FREQ_MAPPING = { + "Y": "1y", + "Q": "1q", + "M": "1mo", + "W": 
"1w", + "D": "1d", + "h": "1h", + "min": "1m", + "s": "1s", + "ms": "1ms", + "us": "1us", + "ns": "1ns", + } + _DTYPE_MAPPING = { # Direct mappings bigframes.dtypes.INT_DTYPE: pl.Int64(), @@ -330,11 +346,48 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: else: return pl.any_horizontal(*(input.str.ends_with(pat) for pat in op.pat)) + @compile_op.register(freq_ops.FloorDtOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + assert isinstance(op, freq_ops.FloorDtOp) + return input.dt.truncate(every=_FREQ_MAPPING[op.freq]) + @compile_op.register(dt_ops.StrftimeOp) def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: assert isinstance(op, dt_ops.StrftimeOp) return input.dt.strftime(op.date_format) + @compile_op.register(date_ops.YearOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.year() + + @compile_op.register(date_ops.QuarterOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.quarter() + + @compile_op.register(date_ops.MonthOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.month() + + @compile_op.register(date_ops.DayOfWeekOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.weekday() - 1 + + @compile_op.register(date_ops.DayOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.day() + + @compile_op.register(date_ops.IsoYearOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.iso_year() + + @compile_op.register(date_ops.IsoWeekOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.week() + + @compile_op.register(date_ops.IsoDayOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + return input.dt.weekday() + @compile_op.register(dt_ops.ParseDatetimeOp) def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: assert isinstance(op, dt_ops.ParseDatetimeOp) diff --git a/bigframes/operations/datetimes.py b/bigframes/operations/datetimes.py index 14bf10f463..95896ddc97 100644 --- a/bigframes/operations/datetimes.py +++ b/bigframes/operations/datetimes.py @@ -30,6 +30,7 @@ _ONE_DAY = pandas.Timedelta("1d") _ONE_SECOND = pandas.Timedelta("1s") _ONE_MICRO = pandas.Timedelta("1us") +_SUPPORTED_FREQS = ("Y", "Q", "M", "W", "D", "h", "min", "s", "ms", "us") @log_adapter.class_logger @@ -155,4 +156,6 @@ def normalize(self) -> series.Series: return self._apply_unary_op(ops.normalize_op) def floor(self, freq: str) -> series.Series: - return self._apply_unary_op(ops.FloorDtOp(freq=freq)) + if freq not in _SUPPORTED_FREQS: + raise ValueError(f"freq must be one of {_SUPPORTED_FREQS}") + return self._apply_unary_op(ops.FloorDtOp(freq=freq)) # type: ignore diff --git a/bigframes/operations/frequency_ops.py b/bigframes/operations/frequency_ops.py index 2d5a854c32..b94afa7271 100644 --- a/bigframes/operations/frequency_ops.py +++ b/bigframes/operations/frequency_ops.py @@ -27,9 +27,22 @@ @dataclasses.dataclass(frozen=True) class FloorDtOp(base_ops.UnaryOp): name: typing.ClassVar[str] = "floor_dt" - freq: str + freq: typing.Literal[ + "Y", + "Q", + "M", + "W", + "D", + "h", + "min", + "s", + "ms", + "us", + ] def output_type(self, *input_types): + if not dtypes.is_datetime_like(input_types[0]): + raise TypeError("dt floor requires datetime-like arguments") return input_types[0] diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index b93d31d255..d8df558fe4 100644 --- a/bigframes/session/polars_executor.py +++ 
b/bigframes/session/polars_executor.py @@ -24,6 +24,8 @@ from bigframes.operations import ( bool_ops, comparison_ops, + date_ops, + frequency_ops, generic_ops, numeric_ops, string_ops, @@ -60,6 +62,15 @@ comparison_ops.GtOp, comparison_ops.LeOp, comparison_ops.GeOp, + date_ops.YearOp, + date_ops.QuarterOp, + date_ops.MonthOp, + date_ops.DayOfWeekOp, + date_ops.DayOp, + date_ops.IsoYearOp, + date_ops.IsoWeekOp, + date_ops.IsoDayOp, + frequency_ops.FloorDtOp, numeric_ops.AddOp, numeric_ops.SubOp, numeric_ops.MulOp, diff --git a/tests/system/small/engines/test_temporal_ops.py b/tests/system/small/engines/test_temporal_ops.py new file mode 100644 index 0000000000..5a39587886 --- /dev/null +++ b/tests/system/small/engines/test_temporal_ops.py @@ -0,0 +1,66 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes.core import array_value +import bigframes.operations as ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_dt_floor(scalars_array_value: array_value.ArrayValue, engine): + arr, _ = scalars_array_value.compute_values( + [ + ops.FloorDtOp("us").as_expr("timestamp_col"), + ops.FloorDtOp("ms").as_expr("timestamp_col"), + ops.FloorDtOp("s").as_expr("timestamp_col"), + ops.FloorDtOp("min").as_expr("timestamp_col"), + ops.FloorDtOp("h").as_expr("timestamp_col"), + ops.FloorDtOp("D").as_expr("timestamp_col"), + ops.FloorDtOp("W").as_expr("timestamp_col"), + ops.FloorDtOp("M").as_expr("timestamp_col"), + ops.FloorDtOp("Q").as_expr("timestamp_col"), + ops.FloorDtOp("Y").as_expr("timestamp_col"), + ops.FloorDtOp("Q").as_expr("datetime_col"), + ops.FloorDtOp("us").as_expr("datetime_col"), + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_date_accessors(scalars_array_value: array_value.ArrayValue, engine): + datelike_cols = ["datetime_col", "timestamp_col", "date_col"] + accessors = [ + ops.day_op, + ops.dayofweek_op, + ops.month_op, + ops.quarter_op, + ops.year_op, + ops.iso_day_op, + ops.iso_week_op, + ops.iso_year_op, + ] + + exprs = [acc.as_expr(col) for acc in accessors for col in datelike_cols] + + arr, _ = scalars_array_value.compute_values(exprs) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_floor_dt/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_floor_dt/out.sql index 3c7efd3098..ad4fdb23a1 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_floor_dt/out.sql +++ 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_floor_dt/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - TIMESTAMP_TRUNC(`bfcol_0`, DAY) AS `bfcol_1` + TIMESTAMP_TRUNC(`bfcol_0`, D) AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index f011721ee5..8f3af11842 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -153,7 +153,7 @@ def test_expm1(scalar_types_df: bpd.DataFrame, snapshot): def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.FloorDtOp("DAY"), "timestamp_col") + sql = _apply_unary_op(bf_df, ops.FloorDtOp("D"), "timestamp_col") snapshot.assert_match(sql, "out.sql") From 70726270e580977ad4e1750d8e0cc2c6c1338ce5 Mon Sep 17 00:00:00 2001 From: jialuoo Date: Fri, 29 Aug 2025 13:38:01 -0700 Subject: [PATCH 20/28] test: Add unit test for get_python_version (#2041) * test: Add unit test for get_python_version * fix --- .../functions/test_remote_function_utils.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/unit/functions/test_remote_function_utils.py b/tests/unit/functions/test_remote_function_utils.py index 8ddd39d857..812d65bbad 100644 --- a/tests/unit/functions/test_remote_function_utils.py +++ b/tests/unit/functions/test_remote_function_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import sys from unittest.mock import patch import bigframes_vendored.constants as constants @@ -227,6 +228,26 @@ def test_get_updated_package_requirements_with_existing_cloudpickle(): assert result == expected +# Dynamically generate expected python versions for the test +_major = sys.version_info.major +_minor = sys.version_info.minor +_compat_version = f"python{_major}{_minor}" +_standard_version = f"python-{_major}.{_minor}" + + +@pytest.mark.parametrize( + "is_compat, expected_version", + [ + (True, _compat_version), + (False, _standard_version), + ], +) +def test_get_python_version(is_compat, expected_version): + """Tests the python version for both standard and compat modes.""" + result = _utils.get_python_version(is_compat=is_compat) + assert result == expected_version + + def test_package_existed_helper(): """Tests the _package_existed helper function directly.""" reqs = ["pandas==1.0", "numpy", "scikit-learn>=1.2.0"] From 164c4818bc4ff2990dca16b9f22a798f47e0a60b Mon Sep 17 00:00:00 2001 From: jialuoo Date: Fri, 29 Aug 2025 14:04:16 -0700 Subject: [PATCH 21/28] feat: Support args in dataframe apply method (#2026) * feat: Allow passing args to managed functions in DataFrame apply method * remove a test * support remote function * resolve the comments * improve the message * fix the tests --- bigframes/dataframe.py | 73 +++++++--- bigframes/functions/_function_session.py | 13 +- bigframes/functions/function.py | 7 - bigframes/functions/function_template.py | 26 +++- .../large/functions/test_managed_function.py | 117 +++++++++++++++- .../large/functions/test_remote_function.py | 126 ++++++++++++++++-- .../small/functions/test_remote_function.py | 14 -- 7 files changed, 320 insertions(+), 56 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 85760d94bc..d618d13aa4 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ 
-77,6 +77,7 @@ import bigframes.exceptions as bfe import bigframes.formatting_helpers as formatter import bigframes.functions +from bigframes.functions import function_typing import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops import bigframes.operations.ai @@ -4835,37 +4836,73 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs): ) # Apply the function - result_series = rows_as_json_series._apply_unary_op( - ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True) - ) + if args: + result_series = rows_as_json_series._apply_nary_op( + ops.NaryRemoteFunctionOp(function_def=func.udf_def), + list(args), + ) + else: + result_series = rows_as_json_series._apply_unary_op( + ops.RemoteFunctionOp( + function_def=func.udf_def, apply_on_null=True + ) + ) else: # This is a special case where we are providing not-pandas-like # extension. If the bigquery function can take one or more - # params then we assume that here the user intention is to use - # the column values of the dataframe as arguments to the - # function. For this to work the following condition must be - # true: - # 1. The number or input params in the function must be same - # as the number of columns in the dataframe + # params (excluding the args) then we assume that here the user + # intention is to use the column values of the dataframe as + # arguments to the function. For this to work the following + # condition must be true: + # 1. The number or input params (excluding the args) in the + # function must be same as the number of columns in the + # dataframe. # 2. The dtypes of the columns in the dataframe must be - # compatible with the data types of the input params + # compatible with the data types of the input params. # 3. The order of the columns in the dataframe must correspond - # to the order of the input params in the function + # to the order of the input params in the function. udf_input_dtypes = func.udf_def.signature.bf_input_types - if len(udf_input_dtypes) != len(self.columns): + if not args and len(udf_input_dtypes) != len(self.columns): raise ValueError( - f"BigFrames BigQuery function takes {len(udf_input_dtypes)}" - f" arguments but DataFrame has {len(self.columns)} columns." + f"Parameter count mismatch: BigFrames BigQuery function" + f" expected {len(udf_input_dtypes)} parameters but" + f" received {len(self.columns)} DataFrame columns." ) - if udf_input_dtypes != tuple(self.dtypes.to_list()): + if args and len(udf_input_dtypes) != len(self.columns) + len(args): raise ValueError( - f"BigFrames BigQuery function takes arguments of types " - f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}." + f"Parameter count mismatch: BigFrames BigQuery function" + f" expected {len(udf_input_dtypes)} parameters but" + f" received {len(self.columns) + len(args)} values" + f" ({len(self.columns)} DataFrame columns and" + f" {len(args)} args)." ) + end_slice = -len(args) if args else None + if udf_input_dtypes[:end_slice] != tuple(self.dtypes.to_list()): + raise ValueError( + f"Data type mismatch for DataFrame columns:" + f" Expected {udf_input_dtypes[:end_slice]}" + f" Received {tuple(self.dtypes)}." 
+ ) + if args: + bq_types = ( + function_typing.sdk_type_from_python_type(type(arg)) + for arg in args + ) + args_dtype = tuple( + function_typing.sdk_type_to_bf_type(bq_type) + for bq_type in bq_types + ) + if udf_input_dtypes[end_slice:] != args_dtype: + raise ValueError( + f"Data type mismatch for 'args' parameter:" + f" Expected {udf_input_dtypes[end_slice:]}" + f" Received {args_dtype}." + ) series_list = [self[col] for col in self.columns] + op_list = series_list[1:] + list(args) result_series = series_list[0]._apply_nary_op( - ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:] + ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list ) result_series.name = None diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index 90bfb89c56..a2fb66539b 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -959,11 +959,16 @@ def _convert_row_processor_sig( ) -> Optional[inspect.Signature]: import bigframes.series as bf_series - if len(signature.parameters) == 1: - only_param = next(iter(signature.parameters.values())) - param_type = only_param.annotation + if len(signature.parameters) >= 1: + first_param = next(iter(signature.parameters.values())) + param_type = first_param.annotation if (param_type == bf_series.Series) or (param_type == pandas.Series): msg = bfe.format_message("input_types=Series is in preview.") warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning) - return signature.replace(parameters=[only_param.replace(annotation=str)]) + return signature.replace( + parameters=[ + p.replace(annotation=str) if i == 0 else p + for i, p in enumerate(signature.parameters.values()) + ] + ) return None diff --git a/bigframes/functions/function.py b/bigframes/functions/function.py index a62da57075..99b89131e7 100644 --- a/bigframes/functions/function.py +++ b/bigframes/functions/function.py @@ -178,13 +178,6 @@ def read_gbq_function( ValueError, f"Unknown function '{routine_ref}'." ) - if is_row_processor and len(routine.arguments) > 1: - raise bf_formatting.create_exception_with_feedback_link( - ValueError, - "A multi-input function cannot be a row processor. A row processor function " - "takes in a single input representing the row.", - ) - if is_row_processor: return _try_import_row_routine(routine, session) else: diff --git a/bigframes/functions/function_template.py b/bigframes/functions/function_template.py index 5f04fcc8e2..dd31de7243 100644 --- a/bigframes/functions/function_template.py +++ b/bigframes/functions/function_template.py @@ -195,7 +195,9 @@ def udf_http_row_processor(request): calls = request_json["calls"] replies = [] for call in calls: - reply = convert_to_bq_json(output_type, udf(get_pd_series(call[0]))) + reply = convert_to_bq_json( + output_type, udf(get_pd_series(call[0]), *call[1:]) + ) if type(reply) is list: # Since the BQ remote function does not support array yet, # return a json serialized version of the reply. @@ -332,6 +334,28 @@ def generate_managed_function_code( f"""def bigframes_handler(str_arg): return {udf_name}({get_pd_series.__name__}(str_arg))""" ) + + sig = inspect.signature(def_) + params = list(sig.parameters.values()) + additional_params = params[1:] + + # Build the parameter list for the new handler function definition. 
+ # e.g., "str_arg, y: bool, z" + handler_def_parts = ["str_arg"] + handler_def_parts.extend(str(p) for p in additional_params) + handler_def_str = ", ".join(handler_def_parts) + + # Build the argument list for the call to the original UDF. + # e.g., "get_pd_series(str_arg), y, z" + udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"] + udf_call_parts.extend(p.name for p in additional_params) + udf_call_str = ", ".join(udf_call_parts) + + bigframes_handler_code = textwrap.dedent( + f"""def bigframes_handler({handler_def_str}): + return {udf_name}({udf_call_str})""" + ) + else: udf_code = "" bigframes_handler_code = textwrap.dedent( diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 73335afa3c..b0e44b648f 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -468,20 +468,20 @@ def foo(x, y, z): # Fails to apply on dataframe with incompatible number of columns. with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.", ): bf_df[["Id", "Age"]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes. with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1) @@ -965,6 +965,117 @@ def float_parser(row): ) +def test_managed_function_df_apply_axis_1_args(session, dataset_id, scalars_dfs): + columns = ["int64_col", "int64_too"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + def the_sum(s1, s2, x): + return s1 + s2 + x + + the_sum_mf = session.udf( + input_types=[int, int, int], + output_type=int, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(the_sum) + + args1 = (1,) + + # Fails to apply on dataframe with incompatible number of columns and args. + with pytest.raises( + ValueError, + match="^Parameter count mismatch:.* expected 3 parameters but received 4 values \\(3 DataFrame columns and 1 args\\)", + ): + scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1) + + # Fails to apply on dataframe with incompatible column datatypes. + with pytest.raises( + ValueError, + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", + ): + scalars_df[columns].assign( + int64_col=lambda df: df["int64_col"].astype("Float64") + ).apply(the_sum_mf, axis=1, args=args1) + + # Fails to apply on dataframe with incompatible args datatypes. 
+ with pytest.raises( + ValueError, + match="^Data type mismatch for 'args' parameter: Expected .* Received .*", + ): + scalars_df[columns].apply(the_sum_mf, axis=1, args=(1.3,)) + + bf_result = ( + scalars_df[columns] + .dropna() + .apply(the_sum_mf, axis=1, args=args1) + .to_pandas() + ) + pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + finally: + # clean up the gcp assets created for the managed function. + cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + +def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scalars_dfs): + columns = ["int64_col", "float64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + def analyze(s, x, y): + value = f"value is {s['int64_col']} and {s['float64_col']}" + if x: + return f"{value}, x is True!" + if y > 0: + return f"{value}, x is False, y is positive!" + return f"{value}, x is False, y is non-positive!" + + analyze_mf = session.udf( + input_types=[bigframes.series.Series, bool, float], + output_type=str, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(analyze) + + args1 = (True, 10.0) + bf_result = ( + scalars_df[columns] + .dropna() + .apply(analyze_mf, axis=1, args=args1) + .to_pandas() + ) + pd_result = ( + scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args1) + ) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + args2 = (False, -10.0) + analyze_mf_ref = session.read_gbq_function( + analyze_mf.bigframes_bigquery_function, is_row_processor=True + ) + bf_result = ( + scalars_df[columns] + .dropna() + .apply(analyze_mf_ref, axis=1, args=args2) + .to_pandas() + ) + pd_result = ( + scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args2) + ) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + finally: + # clean up the gcp assets created for the managed function. + cleanup_function_assets(analyze_mf, session.bqclient, ignore_failures=False) + + def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs): try: diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index 3c453a52a4..e6372d768b 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -1937,6 +1937,114 @@ def float_parser(row): ) +@pytest.mark.flaky(retries=2, delay=120) +def test_df_apply_axis_1_args(session, scalars_dfs): + columns = ["int64_col", "int64_too"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + def the_sum(s1, s2, x): + return s1 + s2 + x + + the_sum_mf = session.remote_function( + input_types=[int, int, int], + output_type=int, + reuse=False, + cloud_function_service_account="default", + )(the_sum) + + args1 = (1,) + + # Fails to apply on dataframe with incompatible number of columns and args. + with pytest.raises( + ValueError, + match="^Parameter count mismatch:.* expected 3 parameters but received 4 values \\(2 DataFrame columns and 2 args\\)", + ): + scalars_df[columns].apply( + the_sum_mf, + axis=1, + args=( + 1, + 1, + ), + ) + + # Fails to apply on dataframe with incompatible column datatypes. 
+ with pytest.raises( + ValueError, + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", + ): + scalars_df[columns].assign( + int64_col=lambda df: df["int64_col"].astype("Float64") + ).apply(the_sum_mf, axis=1, args=args1) + + # Fails to apply on dataframe with incompatible args datatypes. + with pytest.raises( + ValueError, + match="^Data type mismatch for 'args' parameter: Expected .* Received .*", + ): + scalars_df[columns].apply(the_sum_mf, axis=1, args=("hello world",)) + + bf_result = ( + scalars_df[columns] + .dropna() + .apply(the_sum_mf, axis=1, args=args1) + .to_pandas() + ) + pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + finally: + # clean up the gcp assets created for the remote function. + cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_df_apply_axis_1_series_args(session, scalars_dfs): + columns = ["int64_col", "float64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + @session.remote_function( + input_types=[bigframes.series.Series, float, str, bool], + output_type=list[str], + reuse=False, + cloud_function_service_account="default", + ) + def foo_list(x, y0: float, y1, y2) -> list[str]: + return ( + [str(x["int64_col"]), str(y0), str(y1), str(y2)] + if y2 + else [str(x["float64_col"])] + ) + + args1 = (12.34, "hello world", True) + bf_result = scalars_df[columns].apply(foo_list, axis=1, args=args1).to_pandas() + pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=args1) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + args2 = (43.21, "xxx3yyy", False) + foo_list_ref = session.read_gbq_function( + foo_list.bigframes_bigquery_function, is_row_processor=True + ) + bf_result = ( + scalars_df[columns].apply(foo_list_ref, axis=1, args=args2).to_pandas() + ) + pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=args2) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. 
+ cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False) + + @pytest.mark.parametrize( ("memory_mib_args", "expected_memory"), [ @@ -2200,19 +2308,19 @@ def foo(x, y, z): # Fails to apply on dataframe with incompatible number of columns with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.", ): bf_df[["Id", "Age"]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1) @@ -2284,19 +2392,19 @@ def foo(x, y, z): # Fails to apply on dataframe with incompatible number of columns with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.", ): bf_df[["Id", "Age"]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1) @@ -2358,19 +2466,19 @@ def foo(x): # Fails to apply on dataframe with incompatible number of columns with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 0 columns\\.$", + match="^Parameter count mismatch:.* expected 1 parameters but received 0 DataFrame.*", ): bf_df[[]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 1 parameters but received 2 DataFrame.*", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Id=bf_df["Id"].astype("Float64")).apply(foo, axis=1) diff --git a/tests/system/small/functions/test_remote_function.py b/tests/system/small/functions/test_remote_function.py index 86076e764f..28fab19144 100644 --- a/tests/system/small/functions/test_remote_function.py +++ b/tests/system/small/functions/test_remote_function.py @@ -1154,20 +1154,6 @@ def test_df_apply_scalar_func(session, scalars_dfs): ) -def test_read_gbq_function_multiple_inputs_not_a_row_processor(session): - with pytest.raises(ValueError) as context: - # The remote 
function has two args, which cannot be row processed. Throw - # a ValueError for it. - session.read_gbq_function( - function_name="bqutil.fn.cw_regexp_instr_2", - is_row_processor=True, - ) - assert str(context.value) == ( - "A multi-input function cannot be a row processor. A row processor function " - f"takes in a single input representing the row. {constants.FEEDBACK_LINK}" - ) - - @pytest.mark.flaky(retries=2, delay=120) def test_df_apply_axis_1(session, scalars_dfs, dataset_id_permanent): columns = [ From 1a0f710ac11418fd71ab3373f3f6002fa581b180 Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Fri, 29 Aug 2025 14:12:13 -0700 Subject: [PATCH 22/28] feat: Can pivot unordered, unindexed dataframe (#2040) --- bigframes/core/blocks.py | 14 +++++++++++--- bigframes/dataframe.py | 4 ---- tests/system/conftest.py | 12 ++++++++++++ tests/system/small/test_unordered.py | 24 ++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 283f56fd39..f7d456bf9d 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -2129,9 +2129,17 @@ def _get_unique_values( import bigframes.core.block_transforms as block_tf import bigframes.dataframe as df - unique_value_block = block_tf.drop_duplicates( - self.select_columns(columns), columns - ) + if self.explicitly_ordered: + unique_value_block = block_tf.drop_duplicates( + self.select_columns(columns), columns + ) + else: + unique_value_block, _ = self.aggregate(by_column_ids=columns, dropna=False) + col_labels = self._get_labels_for_columns(columns) + unique_value_block = unique_value_block.reset_index( + drop=False + ).with_column_labels(col_labels) + pd_values = ( df.DataFrame(unique_value_block).head(max_unique_values + 1).to_pandas() ) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index d618d13aa4..7f3d51a03e 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3347,8 +3347,6 @@ def _pivot( ) return DataFrame(pivot_block) - @validations.requires_index - @validations.requires_ordering() def pivot( self, *, @@ -3362,8 +3360,6 @@ def pivot( ) -> DataFrame: return self._pivot(columns=columns, index=index, values=values) - @validations.requires_index - @validations.requires_ordering() def pivot_table( self, values: typing.Optional[ diff --git a/tests/system/conftest.py b/tests/system/conftest.py index a75918ed23..70a379fe0e 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -585,6 +585,18 @@ def scalars_df_null_index( ).sort_values("rowindex") +@pytest.fixture(scope="session") +def scalars_df_unordered( + scalars_table_id: str, unordered_session: bigframes.Session +) -> bigframes.dataframe.DataFrame: + """DataFrame pointing at test data.""" + df = unordered_session.read_gbq( + scalars_table_id, index_col=bigframes.enums.DefaultIndexKind.NULL + ) + assert not df._block.explicitly_ordered + return df + + @pytest.fixture(scope="session") def scalars_df_2_default_index( scalars_df_2_index: bigframes.dataframe.DataFrame, diff --git a/tests/system/small/test_unordered.py b/tests/system/small/test_unordered.py index 0825b78037..ccb2140799 100644 --- a/tests/system/small/test_unordered.py +++ b/tests/system/small/test_unordered.py @@ -265,3 +265,27 @@ def test__resample_with_index(unordered_session, rule, origin, data): pd.testing.assert_frame_equal( bf_result, pd_result, check_dtype=False, check_index_type=False ) + + +@pytest.mark.parametrize( + ("values", "index", "columns"), + [ + ("int64_col", "int64_too", 
["string_col"]), + (["int64_col"], "int64_too", ["string_col"]), + (["int64_col", "float64_col"], "int64_too", ["string_col"]), + ], +) +def test_unordered_df_pivot( + scalars_df_unordered, scalars_pandas_df_index, values, index, columns +): + bf_result = scalars_df_unordered.pivot( + values=values, index=index, columns=columns + ).to_pandas() + pd_result = scalars_pandas_df_index.pivot( + values=values, index=index, columns=columns + ) + + # Pandas produces NaN, where bq dataframes produces pd.NA + bf_result = bf_result.fillna(float("nan")) + pd_result = pd_result.fillna(float("nan")) + pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) From 8689199aa82212ed300fff592097093812e0290e Mon Sep 17 00:00:00 2001 From: jialuoo Date: Fri, 29 Aug 2025 15:40:54 -0700 Subject: [PATCH 23/28] fix: Resolve the validation issue for other arg in dataframe where method (#2042) --- bigframes/dataframe.py | 6 ++-- .../large/functions/test_managed_function.py | 31 ++++++++++++++++++ .../large/functions/test_remote_function.py | 32 +++++++++++++++++++ tests/system/small/test_dataframe.py | 12 +++++++ 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 7f3d51a03e..a5ecd82d47 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2877,9 +2877,6 @@ def _apply_callable(self, condition): return condition def where(self, cond, other=None): - if isinstance(other, bigframes.series.Series): - raise ValueError("Seires is not a supported replacement type!") - if self.columns.nlevels > 1: raise NotImplementedError( "The dataframe.where() method does not support multi-column." @@ -2890,6 +2887,9 @@ def where(self, cond, other=None): cond = self._apply_callable(cond) other = self._apply_callable(other) + if isinstance(other, bigframes.series.Series): + raise ValueError("Seires is not a supported replacement type!") + aligned_block, (_, _) = self._block.join(cond._block, how="left") # No left join is needed when 'other' is None or constant. if isinstance(other, bigframes.dataframe.DataFrame): diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index b0e44b648f..0a04480a78 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1214,6 +1214,37 @@ def func_for_other(x): ) +def test_managed_function_df_where_other_issue(session, dataset_id, scalars_df_index): + try: + + def the_sum(s): + return s["int64_col"] + s["int64_too"] + + the_sum_mf = session.udf( + input_types=bigframes.series.Series, + output_type=int, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(the_sum) + + int64_cols = ["int64_col", "int64_too"] + + bf_int64_df = scalars_df_index[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + + with pytest.raises( + ValueError, + match="Seires is not a supported replacement type!", + ): + # The execution of the callable other=the_sum_mf will return a + # Series, which is not a supported replacement type. + bf_int64_df_filtered.where(cond=bf_int64_df_filtered, other=the_sum_mf) + + finally: + # Clean up the gcp assets created for the managed function. 
+ cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs): try: diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index e6372d768b..f60786437f 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -3004,6 +3004,38 @@ def is_sum_positive(a, b): ) +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_df_where_other_issue(session, dataset_id, scalars_df_index): + try: + + def the_sum(a, b): + return a + b + + the_sum_mf = session.remote_function( + input_types=[int, float], + output_type=float, + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + )(the_sum) + + int64_cols = ["int64_col", "float64_col"] + bf_int64_df = scalars_df_index[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + + with pytest.raises( + ValueError, + match="Seires is not a supported replacement type!", + ): + # The execution of the callable other=the_sum_mf will return a + # Series, which is not a supported replacement type. + bf_int64_df_filtered.where(cond=bf_int64_df > 100, other=the_sum_mf) + + finally: + # Clean up the gcp assets created for the remote function. + cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + @pytest.mark.flaky(retries=2, delay=120) def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs): try: diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index c7f9627531..dce0a649f6 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -570,6 +570,18 @@ def func(x): pandas.testing.assert_frame_equal(bf_result, pd_result) +def test_where_series_other(scalars_df_index): + # When other is a series, throw an error. 
+ columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + + with pytest.raises( + ValueError, + match="Seires is not a supported replacement type!", + ): + dataframe_bf.where(dataframe_bf > 0, dataframe_bf["int64_col"]) + + def test_drop_column(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_col" From a7963fe57a0e141debf726f0bc7b0e953ebe9634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a=20=28Swast=29?= Date: Tue, 2 Sep 2025 18:32:07 +0000 Subject: [PATCH 24/28] feat!: add `allow_large_results` option to `read_gbq_query`, aligning with `bpd.options.compute.allow_large_results` option (#1935) Release-As: 2.18.0 --- bigframes/bigquery/_operations/search.py | 17 ++- bigframes/dataframe.py | 2 +- bigframes/ml/core.py | 92 ++++++++++-- bigframes/operations/ai.py | 4 + bigframes/pandas/io/api.py | 28 +++- bigframes/session/__init__.py | 125 +++++++++++++--- .../session/_io/bigquery/read_gbq_query.py | 57 +++++-- bigframes/session/loader.py | 21 ++- .../small/bigquery/test_vector_search.py | 141 ++++++++---------- tests/system/small/ml/test_forecasting.py | 42 ++++-- tests/system/small/ml/test_preprocessing.py | 2 +- .../small/session/test_read_gbq_query.py | 113 ++++++++++++++ tests/system/small/test_pandas_options.py | 20 +-- tests/system/small/test_session.py | 2 +- tests/system/small/test_unordered.py | 2 +- tests/unit/ml/test_golden_sql.py | 31 ++-- tests/unit/session/test_read_gbq_query.py | 2 +- .../bigframes_vendored/pandas/io/gbq.py | 6 + 18 files changed, 529 insertions(+), 178 deletions(-) create mode 100644 tests/system/small/session/test_read_gbq_query.py diff --git a/bigframes/bigquery/_operations/search.py b/bigframes/bigquery/_operations/search.py index 9a1e4b5ac9..5063fc9118 100644 --- a/bigframes/bigquery/_operations/search.py +++ b/bigframes/bigquery/_operations/search.py @@ -99,6 +99,7 @@ def vector_search( distance_type: Optional[Literal["euclidean", "cosine", "dot_product"]] = None, fraction_lists_to_search: Optional[float] = None, use_brute_force: Optional[bool] = None, + allow_large_results: Optional[bool] = None, ) -> dataframe.DataFrame: """ Conduct vector search which searches embeddings to find semantically similar entities. @@ -163,12 +164,12 @@ def vector_search( ... query=search_query, ... distance_type="cosine", ... query_column_to_search="another_embedding", - ... top_k=2) + ... top_k=2).sort_values("id") query_id embedding another_embedding id my_embedding distance - 1 cat [3. 5.2] [3.3 5.2] 2 [2. 4.] 0.005181 - 0 dog [1. 2.] [0.7 2.2] 4 [1. 3.2] 0.000013 1 cat [3. 5.2] [3.3 5.2] 1 [1. 2.] 0.005181 + 1 cat [3. 5.2] [3.3 5.2] 2 [2. 4.] 0.005181 0 dog [1. 2.] [0.7 2.2] 3 [1.5 7. ] 0.004697 + 0 dog [1. 2.] [0.7 2.2] 4 [1. 3.2] 0.000013 [4 rows x 6 columns] @@ -199,6 +200,10 @@ def vector_search( use_brute_force (bool): Determines whether to use brute force search by skipping the vector index if one is available. Default to False. + allow_large_results (bool, optional): + Whether to allow large query results. If ``True``, the query + results can be larger than the maximum response size. + Defaults to ``bpd.options.compute.allow_large_results``. Returns: bigframes.dataframe.DataFrame: A DataFrame containing vector search result. 
@@ -236,9 +241,11 @@
         options=options,
     )
     if index_col_ids is not None:
-        df = query._session.read_gbq(sql, index_col=index_col_ids)
+        df = query._session.read_gbq_query(
+            sql, index_col=index_col_ids, allow_large_results=allow_large_results
+        )
         df.index.names = index_labels
     else:
-        df = query._session.read_gbq(sql)
+        df = query._session.read_gbq_query(sql, allow_large_results=allow_large_results)

     return df
diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index a5ecd82d47..75be1c256e 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -4496,7 +4496,7 @@ def to_dict(
         allow_large_results: Optional[bool] = None,
         **kwargs,
     ) -> dict | list[dict]:
-        return self.to_pandas(allow_large_results=allow_large_results).to_dict(orient, into, **kwargs)  # type: ignore
+        return self.to_pandas(allow_large_results=allow_large_results).to_dict(orient=orient, into=into, **kwargs)  # type: ignore

     def to_excel(
         self,
diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index 73b8ba8dbc..28f795a0b6 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -45,7 +45,11 @@ def ai_forecast(
         result_sql = self._sql_generator.ai_forecast(
             source_sql=input_data.sql, options=options
         )
-        return self._session.read_gbq(result_sql)
+
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(result_sql, allow_large_results=True)


 class BqmlModel(BaseBqml):
@@ -95,7 +99,17 @@ def _apply_ml_tvf(
         )
         result_sql = apply_sql_tvf(input_sql)
-        df = self._session.read_gbq(result_sql, index_col=index_col_ids)
+        df = self._session.read_gbq_query(
+            result_sql,
+            index_col=index_col_ids,
+            # Many ML methods use nested JSON, which isn't yet compatible with
+            # joining local results. Also, there is a chance that the results
+            # are greater than 10 GB.
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            allow_large_results=True,
+        )
         if df._has_index:
             df.index.names = index_labels
         # Restore column labels
@@ -159,7 +173,10 @@ def explain_predict(

     def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_global_explain(struct_options=options)
         return (
-            self._session.read_gbq(sql)
+            # TODO(b/395912450): Once the limitations with local data are
+            # resolved, consider setting allow_large_results only when expected
+            # data size is large.
+            self._session.read_gbq_query(sql, allow_large_results=True)
             .sort_values(by="attribution", ascending=False)
             .set_index("feature")
         )
@@ -234,26 +251,49 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_forecast(struct_options=options)
         timestamp_col_name = "forecast_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
         sql = self._sql_generator.ml_explain_forecast(struct_options=options)
         timestamp_col_name = "time_series_timestamp"
         index_cols = [timestamp_col_name]
-        first_col_name = self._session.read_gbq(sql).columns.values[0]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        first_col_name = self._session.read_gbq_query(
+            sql, allow_large_results=True
+        ).columns.values[0]
         if timestamp_col_name != first_col_name:
             index_cols.append(first_col_name)
-        return self._session.read_gbq(sql, index_col=index_cols).reset_index()
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=index_cols, allow_large_results=True
+        ).reset_index()

     def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
         sql = self._sql_generator.ml_evaluate(
             input_data.sql if (input_data is not None) else None
         )
-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def llm_evaluate(
         self,
@@ -262,25 +302,37 @@ def llm_evaluate(
     ):
         sql = self._sql_generator.ml_llm_evaluate(input_data.sql, task_type)

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_evaluate(self, show_all_candidate_models: bool = False):
         sql = self._sql_generator.ml_arima_evaluate(show_all_candidate_models)

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def arima_coefficients(self) -> bpd.DataFrame:
         sql = self._sql_generator.ml_arima_coefficients()

-        return self._session.read_gbq(sql)
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(sql, allow_large_results=True)

     def centroids(self) -> bpd.DataFrame:
         assert self._model.model_type == "KMEANS"

         sql = self._sql_generator.ml_centroids()

-        return self._session.read_gbq(
-            sql, index_col=["centroid_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+        return self._session.read_gbq_query(
+            sql, index_col=["centroid_id", "feature"], allow_large_results=True
         ).reset_index()

     def principal_components(self) -> bpd.DataFrame:
@@ -288,8 +340,13 @@ def principal_components(self) -> bpd.DataFrame:

         sql = self._sql_generator.ml_principal_components()

-        return self._session.read_gbq(
-            sql, index_col=["principal_component_id", "feature"]
+        # TODO(b/395912450): Once the limitations with local data are
+        # resolved, consider setting allow_large_results only when expected
+        # data size is large.
+ return self._session.read_gbq_query( + sql, index_col=["centroid_id", "feature"], allow_large_results=True ).reset_index() def principal_components(self) -> bpd.DataFrame: @@ -288,8 +340,13 @@ def principal_components(self) -> bpd.DataFrame: sql = self._sql_generator.ml_principal_components() - return self._session.read_gbq( - sql, index_col=["principal_component_id", "feature"] + # TODO(b/395912450): Once the limitations with local data are + # resolved, consider setting allow_large_results only when expected + # data size is large. + return self._session.read_gbq_query( + sql, + index_col=["principal_component_id", "feature"], + allow_large_results=True, ).reset_index() def principal_component_info(self) -> bpd.DataFrame: @@ -297,7 +354,10 @@ def principal_component_info(self) -> bpd.DataFrame: sql = self._sql_generator.ml_principal_component_info() - return self._session.read_gbq(sql) + # TODO(b/395912450): Once the limitations with local data are + # resolved, consider setting allow_large_results only when expected + # data size is large. + return self._session.read_gbq_query(sql, allow_large_results=True) def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel: job_config = self._session._prepare_copy_job_config() diff --git a/bigframes/operations/ai.py b/bigframes/operations/ai.py index 8c7628059a..ac294b0fbd 100644 --- a/bigframes/operations/ai.py +++ b/bigframes/operations/ai.py @@ -566,6 +566,10 @@ def search( column_to_search=embedding_result_column, query=query_df, top_k=top_k, + # TODO(tswast): set allow_large_results based on Series size. + # If we expect small results, it could be faster to set + # allow_large_results to False. + allow_large_results=True, ) .rename(columns={"content": search_column}) .set_index("index") diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index cf4b4eb19c..483bc5e530 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -187,6 +187,7 @@ def read_gbq( # type: ignore[overload-overlap] use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[False] = ..., + allow_large_results: Optional[bool] = ..., ) -> bigframes.dataframe.DataFrame: ... @@ -203,6 +204,7 @@ def read_gbq( use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[True] = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -218,6 +220,7 @@ def read_gbq( use_cache: Optional[bool] = None, col_order: Iterable[str] = (), dry_run: bool = False, + allow_large_results: Optional[bool] = None, ) -> bigframes.dataframe.DataFrame | pandas.Series: _set_default_session_location_if_possible(query_or_table) return global_session.with_default_session( @@ -231,6 +234,7 @@ def read_gbq( use_cache=use_cache, col_order=col_order, dry_run=dry_run, + allow_large_results=allow_large_results, ) @@ -400,6 +404,7 @@ def read_gbq_query( # type: ignore[overload-overlap] col_order: Iterable[str] = ..., filters: vendored_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., + allow_large_results: Optional[bool] = ..., ) -> bigframes.dataframe.DataFrame: ... @@ -416,6 +421,7 @@ def read_gbq_query( col_order: Iterable[str] = ..., filters: vendored_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... 
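A minimal usage sketch of the new ``allow_large_results`` argument as it is exposed through ``bigframes.pandas`` (the query strings and table name below are placeholders; when the argument is omitted the session-level default applies):

    import bigframes.pandas as bpd

    # Small results: read rows directly from the query response,
    # no destination table required.
    df_small = bpd.read_gbq_query("SELECT 1 AS x", allow_large_results=False)

    # Results that may exceed 10 GB: route through a destination table.
    df_large = bpd.read_gbq_query(
        "SELECT * FROM `project.dataset.big_table`",  # placeholder table
        allow_large_results=True,
    )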
@@ -431,6 +437,7 @@ def read_gbq_query( col_order: Iterable[str] = (), filters: vendored_pandas_gbq.FiltersType = (), dry_run: bool = False, + allow_large_results: Optional[bool] = None, ) -> bigframes.dataframe.DataFrame | pandas.Series: _set_default_session_location_if_possible(query) return global_session.with_default_session( @@ -444,6 +451,7 @@ def read_gbq_query( col_order=col_order, filters=filters, dry_run=dry_run, + allow_large_results=allow_large_results, ) @@ -617,7 +625,11 @@ def from_glob_path( def _get_bqclient() -> bigquery.Client: - clients_provider = bigframes.session.clients.ClientsProvider( + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. + from bigframes.session import clients + + clients_provider = clients.ClientsProvider( project=config.options.bigquery.project, location=config.options.bigquery.location, use_regional_endpoints=config.options.bigquery.use_regional_endpoints, @@ -631,11 +643,15 @@ def _get_bqclient() -> bigquery.Client: def _dry_run(query, bqclient) -> bigquery.QueryJob: + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. + from bigframes.session import metrics as bf_metrics + job = bqclient.query(query, bigquery.QueryJobConfig(dry_run=True)) # Fix for b/435183833. Log metrics even if a Session isn't available. - if bigframes.session.metrics.LOGGING_NAME_ENV_VAR in os.environ: - metrics = bigframes.session.metrics.ExecutionMetrics() + if bf_metrics.LOGGING_NAME_ENV_VAR in os.environ: + metrics = bf_metrics.ExecutionMetrics() metrics.count_job_stats(job) return job @@ -645,6 +661,10 @@ def _set_default_session_location_if_possible(query): def _set_default_session_location_if_possible_deferred_query(create_query): + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. + from bigframes.session._io import bigquery + # Set the location as per the query if this is the first query the user is # running and: # (1) Default session has not started yet, and @@ -666,7 +686,7 @@ def _set_default_session_location_if_possible_deferred_query(create_query): query = create_query() bqclient = _get_bqclient() - if bigframes.session._io.bigquery.is_query(query): + if bigquery.is_query(query): # Intentionally run outside of the session so that we can detect the # location before creating the session. Since it's a dry_run, labels # aren't necessary. diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 66b0196286..432e73159a 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -62,6 +62,7 @@ from bigframes import exceptions as bfe from bigframes import version +import bigframes._config import bigframes._config.bigquery_options as bigquery_options import bigframes.clients import bigframes.constants @@ -134,6 +135,10 @@ def __init__( context: Optional[bigquery_options.BigQueryOptions] = None, clients_provider: Optional[bigframes.session.clients.ClientsProvider] = None, ): + # Address circular imports in doctest due to bigframes/session/__init__.py + # containing a lot of logic and samples. 
+ from bigframes.session import anonymous_dataset, clients, loader, metrics + _warn_if_bf_version_is_obsolete() if context is None: @@ -169,7 +174,7 @@ def __init__( if clients_provider: self._clients_provider = clients_provider else: - self._clients_provider = bigframes.session.clients.ClientsProvider( + self._clients_provider = clients.ClientsProvider( project=context.project, location=self._location, use_regional_endpoints=context.use_regional_endpoints, @@ -221,15 +226,13 @@ def __init__( else bigframes.enums.DefaultIndexKind.NULL ) - self._metrics = bigframes.session.metrics.ExecutionMetrics() + self._metrics = metrics.ExecutionMetrics() self._function_session = bff_session.FunctionSession() - self._anon_dataset_manager = ( - bigframes.session.anonymous_dataset.AnonymousDatasetManager( - self._clients_provider.bqclient, - location=self._location, - session_id=self._session_id, - kms_key=self._bq_kms_key_name, - ) + self._anon_dataset_manager = anonymous_dataset.AnonymousDatasetManager( + self._clients_provider.bqclient, + location=self._location, + session_id=self._session_id, + kms_key=self._bq_kms_key_name, ) # Session temp tables don't support specifying kms key, so use anon dataset if kms key specified self._session_resource_manager = ( @@ -243,7 +246,7 @@ def __init__( self._temp_storage_manager = ( self._session_resource_manager or self._anon_dataset_manager ) - self._loader = bigframes.session.loader.GbqDataLoader( + self._loader = loader.GbqDataLoader( session=self, bqclient=self._clients_provider.bqclient, storage_manager=self._temp_storage_manager, @@ -397,6 +400,7 @@ def read_gbq( # type: ignore[overload-overlap] use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[False] = ..., + allow_large_results: Optional[bool] = ..., ) -> dataframe.DataFrame: ... @@ -413,6 +417,7 @@ def read_gbq( use_cache: Optional[bool] = ..., col_order: Iterable[str] = ..., dry_run: Literal[True] = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -427,8 +432,8 @@ def read_gbq( filters: third_party_pandas_gbq.FiltersType = (), use_cache: Optional[bool] = None, col_order: Iterable[str] = (), - dry_run: bool = False - # Add a verify index argument that fails if the index is not unique. + dry_run: bool = False, + allow_large_results: Optional[bool] = None, ) -> dataframe.DataFrame | pandas.Series: # TODO(b/281571214): Generate prompt to show the progress of read_gbq. 
if columns and col_order: @@ -438,6 +443,9 @@ def read_gbq( elif col_order: columns = col_order + if allow_large_results is None: + allow_large_results = bigframes._config.options._allow_large_results + if bf_io_bigquery.is_query(query_or_table): return self._loader.read_gbq_query( # type: ignore # for dry_run overload query_or_table, @@ -448,6 +456,7 @@ def read_gbq( use_cache=use_cache, filters=filters, dry_run=dry_run, + allow_large_results=allow_large_results, ) else: if configuration is not None: @@ -523,6 +532,8 @@ def _read_gbq_colab( if pyformat_args is None: pyformat_args = {} + allow_large_results = bigframes._config.options._allow_large_results + query = bigframes.core.pyformat.pyformat( query, pyformat_args=pyformat_args, @@ -535,10 +546,7 @@ def _read_gbq_colab( index_col=bigframes.enums.DefaultIndexKind.NULL, force_total_order=False, dry_run=typing.cast(Union[Literal[False], Literal[True]], dry_run), - # TODO(tswast): we may need to allow allow_large_results to be overwritten - # or possibly a general configuration object for an explicit - # destination table and write disposition. - allow_large_results=False, + allow_large_results=allow_large_results, ) @overload @@ -554,6 +562,7 @@ def read_gbq_query( # type: ignore[overload-overlap] col_order: Iterable[str] = ..., filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., + allow_large_results: Optional[bool] = ..., ) -> dataframe.DataFrame: ... @@ -570,6 +579,7 @@ def read_gbq_query( col_order: Iterable[str] = ..., filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., + allow_large_results: Optional[bool] = ..., ) -> pandas.Series: ... @@ -585,6 +595,7 @@ def read_gbq_query( col_order: Iterable[str] = (), filters: third_party_pandas_gbq.FiltersType = (), dry_run: bool = False, + allow_large_results: Optional[bool] = None, ) -> dataframe.DataFrame | pandas.Series: """Turn a SQL query into a DataFrame. @@ -634,9 +645,48 @@ def read_gbq_query( See also: :meth:`Session.read_gbq`. + Args: + query (str): + A SQL query to execute. + index_col (Iterable[str] or str, optional): + The column(s) to use as the index for the DataFrame. This can be + a single column name or a list of column names. If not provided, + a default index will be used. + columns (Iterable[str], optional): + The columns to read from the query result. If not + specified, all columns will be read. + configuration (dict, optional): + A dictionary of query job configuration options. See the + BigQuery REST API documentation for a list of available options: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query + max_results (int, optional): + The maximum number of rows to retrieve from the query + result. If not specified, all rows will be loaded. + use_cache (bool, optional): + Whether to use cached results for the query. Defaults to ``True``. + Setting this to ``False`` will force a re-execution of the query. + col_order (Iterable[str], optional): + The desired order of columns in the resulting DataFrame. This + parameter is deprecated and will be removed in a future version. + Use ``columns`` instead. + filters (list[tuple], optional): + A list of filters to apply to the data. Filters are specified + as a list of tuples, where each tuple contains a column name, + an operator (e.g., '==', '!='), and a value. + dry_run (bool, optional): + If ``True``, the function will not actually execute the query but + will instead return statistics about the query. Defaults to + ``False``. 
+ allow_large_results (bool, optional): + Whether to allow large query results. If ``True``, the query + results can be larger than the maximum response size. + Defaults to ``bpd.options.compute.allow_large_results``. + Returns: - bigframes.pandas.DataFrame: - A DataFrame representing results of the query or table. + bigframes.pandas.DataFrame or pandas.Series: + A DataFrame representing the result of the query. If ``dry_run`` + is ``True``, a ``pandas.Series`` containing query statistics is + returned. Raises: ValueError: @@ -651,6 +701,9 @@ def read_gbq_query( elif col_order: columns = col_order + if allow_large_results is None: + allow_large_results = bigframes._config.options._allow_large_results + return self._loader.read_gbq_query( # type: ignore # for dry_run overload query=query, index_col=index_col, @@ -660,6 +713,7 @@ def read_gbq_query( use_cache=use_cache, filters=filters, dry_run=dry_run, + allow_large_results=allow_large_results, ) @overload @@ -717,9 +771,40 @@ def read_gbq_table( See also: :meth:`Session.read_gbq`. + Args: + table_id (str): + The identifier of the BigQuery table to read. + index_col (Iterable[str] or str, optional): + The column(s) to use as the index for the DataFrame. This can be + a single column name or a list of column names. If not provided, + a default index will be used. + columns (Iterable[str], optional): + The columns to read from the table. If not specified, all + columns will be read. + max_results (int, optional): + The maximum number of rows to retrieve from the table. If not + specified, all rows will be loaded. + filters (list[tuple], optional): + A list of filters to apply to the data. Filters are specified + as a list of tuples, where each tuple contains a column name, + an operator (e.g., '==', '!='), and a value. + use_cache (bool, optional): + Whether to use cached results for the query. Defaults to ``True``. + Setting this to ``False`` will force a re-execution of the query. + col_order (Iterable[str], optional): + The desired order of columns in the resulting DataFrame. This + parameter is deprecated and will be removed in a future version. + Use ``columns`` instead. + dry_run (bool, optional): + If ``True``, the function will not actually execute the query but + will instead return statistics about the table. Defaults to + ``False``. + Returns: - bigframes.pandas.DataFrame: - A DataFrame representing results of the query or table. + bigframes.pandas.DataFrame or pandas.Series: + A DataFrame representing the contents of the table. If + ``dry_run`` is ``True``, a ``pandas.Series`` containing table + statistics is returned. 
Raises: ValueError: diff --git a/bigframes/session/_io/bigquery/read_gbq_query.py b/bigframes/session/_io/bigquery/read_gbq_query.py index 70c83d7875..aed77615ce 100644 --- a/bigframes/session/_io/bigquery/read_gbq_query.py +++ b/bigframes/session/_io/bigquery/read_gbq_query.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Optional +from typing import cast, Iterable, Optional, Tuple from google.cloud import bigquery import google.cloud.bigquery.table @@ -28,6 +28,7 @@ import bigframes.core.blocks as blocks import bigframes.core.guid import bigframes.core.schema as schemata +import bigframes.enums import bigframes.session @@ -53,7 +54,11 @@ def create_dataframe_from_query_job_stats( def create_dataframe_from_row_iterator( - rows: google.cloud.bigquery.table.RowIterator, *, session: bigframes.session.Session + rows: google.cloud.bigquery.table.RowIterator, + *, + session: bigframes.session.Session, + index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind, + columns: Iterable[str], ) -> dataframe.DataFrame: """Convert a RowIterator into a DataFrame wrapping a LocalNode. @@ -61,11 +66,27 @@ def create_dataframe_from_row_iterator( 'jobless' case where there's no destination table. """ pa_table = rows.to_arrow() + bq_schema = list(rows.schema) + is_default_index = not index_col or isinstance( + index_col, bigframes.enums.DefaultIndexKind + ) - # TODO(tswast): Use array_value.promote_offsets() instead once that node is - # supported by the local engine. - offsets_col = bigframes.core.guid.generate_guid() - pa_table = pyarrow_utils.append_offsets(pa_table, offsets_col=offsets_col) + if is_default_index: + # We get a sequential index for free, so use that if no index is specified. + # TODO(tswast): Use array_value.promote_offsets() instead once that node is + # supported by the local engine. + offsets_col = bigframes.core.guid.generate_guid() + pa_table = pyarrow_utils.append_offsets(pa_table, offsets_col=offsets_col) + bq_schema += [bigquery.SchemaField(offsets_col, "INTEGER")] + index_columns: Tuple[str, ...] = (offsets_col,) + index_labels: Tuple[Optional[str], ...] = (None,) + elif isinstance(index_col, str): + index_columns = (index_col,) + index_labels = (index_col,) + else: + index_col = cast(Iterable[str], index_col) + index_columns = tuple(index_col) + index_labels = cast(Tuple[Optional[str], ...], tuple(index_col)) # We use the ManagedArrowTable constructor directly, because the # results of to_arrow() should be the source of truth with regards @@ -74,17 +95,27 @@ def create_dataframe_from_row_iterator( # like the output of the BQ Storage Read API. 
mat = local_data.ManagedArrowTable( pa_table, - schemata.ArraySchema.from_bq_schema( - list(rows.schema) + [bigquery.SchemaField(offsets_col, "INTEGER")] - ), + schemata.ArraySchema.from_bq_schema(bq_schema), ) mat.validate() + column_labels = [ + field.name for field in rows.schema if field.name not in index_columns + ] + array_value = core.ArrayValue.from_managed(mat, session) block = blocks.Block( array_value, - (offsets_col,), - [field.name for field in rows.schema], - (None,), + index_columns=index_columns, + column_labels=column_labels, + index_labels=index_labels, ) - return dataframe.DataFrame(block) + df = dataframe.DataFrame(block) + + if columns: + df = df[list(columns)] + + if not is_default_index: + df = df.sort_index() + + return df diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 6500701324..49b1195235 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -721,6 +721,9 @@ def read_gbq_table( columns=columns, use_cache=use_cache, dry_run=dry_run, + # If max_results has been set, we almost certainly have < 10 GB + # of results. + allow_large_results=False, ) return df @@ -895,7 +898,7 @@ def read_gbq_query( # type: ignore[overload-overlap] filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[False] = ..., force_total_order: Optional[bool] = ..., - allow_large_results: bool = ..., + allow_large_results: bool, ) -> dataframe.DataFrame: ... @@ -912,7 +915,7 @@ def read_gbq_query( filters: third_party_pandas_gbq.FiltersType = ..., dry_run: Literal[True] = ..., force_total_order: Optional[bool] = ..., - allow_large_results: bool = ..., + allow_large_results: bool, ) -> pandas.Series: ... @@ -928,7 +931,7 @@ def read_gbq_query( filters: third_party_pandas_gbq.FiltersType = (), dry_run: bool = False, force_total_order: Optional[bool] = None, - allow_large_results: bool = True, + allow_large_results: bool, ) -> dataframe.DataFrame | pandas.Series: configuration = _transform_read_gbq_configuration(configuration) @@ -953,6 +956,7 @@ def read_gbq_query( True if use_cache is None else use_cache ) + _check_duplicates("columns", columns) index_cols = _to_index_cols(index_col) _check_index_col_param(index_cols, columns) @@ -1040,10 +1044,19 @@ def read_gbq_query( # local node. Likely there are a wide range of sizes in which it # makes sense to download the results beyond the first page, even if # there is a job and destination table available. 
- if rows is not None and destination is None: + if ( + rows is not None + and destination is None + and ( + query_job_for_metrics is None + or query_job_for_metrics.statement_type == "SELECT" + ) + ): return bf_read_gbq_query.create_dataframe_from_row_iterator( rows, session=self._session, + index_col=index_col, + columns=columns, ) # If there was no destination table and we've made it this far, that diff --git a/tests/system/small/bigquery/test_vector_search.py b/tests/system/small/bigquery/test_vector_search.py index a282135fa6..3107795730 100644 --- a/tests/system/small/bigquery/test_vector_search.py +++ b/tests/system/small/bigquery/test_vector_search.py @@ -123,12 +123,17 @@ def test_vector_search_basic_params_with_df(): "embedding": [[1.0, 2.0], [3.0, 5.2]], } ) - vector_search_result = bbq.vector_search( - base_table="bigframes-dev.bigframes_tests_sys.base_table", - column_to_search="my_embedding", - query=search_query, - top_k=2, - ).to_pandas() # type:ignore + vector_search_result = ( + bbq.vector_search( + base_table="bigframes-dev.bigframes_tests_sys.base_table", + column_to_search="my_embedding", + query=search_query, + top_k=2, + ) + .sort_values("distance") + .sort_index() + .to_pandas() + ) # type:ignore expected = pd.DataFrame( { "query_id": ["cat", "dog", "dog", "cat"], @@ -157,80 +162,60 @@ def test_vector_search_basic_params_with_df(): ) -def test_vector_search_different_params_with_query(): - search_query = bpd.Series([[1.0, 2.0], [3.0, 5.2]]) - vector_search_result = bbq.vector_search( - base_table="bigframes-dev.bigframes_tests_sys.base_table", - column_to_search="my_embedding", - query=search_query, - distance_type="cosine", - top_k=2, - ).to_pandas() # type:ignore - expected = pd.DataFrame( +def test_vector_search_different_params_with_query(session): + base_df = bpd.DataFrame( { - "0": [ - np.array([1.0, 2.0]), - np.array([1.0, 2.0]), - np.array([3.0, 5.2]), - np.array([3.0, 5.2]), - ], - "id": [2, 1, 1, 2], + "id": [1, 2, 3, 4], "my_embedding": [ - np.array([2.0, 4.0]), - np.array([1.0, 2.0]), - np.array([1.0, 2.0]), - np.array([2.0, 4.0]), + np.array([0.0, 1.0]), + np.array([1.0, 0.0]), + np.array([0.0, -1.0]), + np.array([-1.0, 0.0]), ], - "distance": [0.0, 0.0, 0.001777, 0.001777], }, - index=pd.Index([0, 0, 1, 1], dtype="Int64"), - ) - pd.testing.assert_frame_equal( - vector_search_result, expected, check_dtype=False, rtol=0.1 - ) - - -def test_vector_search_df_with_query_column_to_search(): - search_query = bpd.DataFrame( - { - "query_id": ["dog", "cat"], - "embedding": [[1.0, 2.0], [3.0, 5.2]], - "another_embedding": [[1.0, 2.5], [3.3, 5.2]], - } - ) - vector_search_result = bbq.vector_search( - base_table="bigframes-dev.bigframes_tests_sys.base_table", - column_to_search="my_embedding", - query=search_query, - query_column_to_search="another_embedding", - top_k=2, - ).to_pandas() # type:ignore - expected = pd.DataFrame( - { - "query_id": ["dog", "dog", "cat", "cat"], - "embedding": [ - np.array([1.0, 2.0]), - np.array([1.0, 2.0]), - np.array([3.0, 5.2]), - np.array([3.0, 5.2]), - ], - "another_embedding": [ - np.array([1.0, 2.5]), - np.array([1.0, 2.5]), - np.array([3.3, 5.2]), - np.array([3.3, 5.2]), - ], - "id": [1, 4, 2, 5], - "my_embedding": [ - np.array([1.0, 2.0]), - np.array([1.0, 3.2]), - np.array([2.0, 4.0]), - np.array([5.0, 5.4]), - ], - "distance": [0.5, 0.7, 1.769181, 1.711724], - }, - index=pd.Index([0, 0, 1, 1], dtype="Int64"), - ) - pd.testing.assert_frame_equal( - vector_search_result, expected, check_dtype=False, rtol=0.1 + 
session=session, ) + base_table = base_df.to_gbq() + try: + search_query = bpd.Series([[0.75, 0.25], [-0.25, -0.75]], session=session) + vector_search_result = ( + bbq.vector_search( + base_table=base_table, + column_to_search="my_embedding", + query=search_query, + distance_type="cosine", + top_k=2, + ) + .sort_values("distance") + .sort_index() + .to_pandas() + ) # type:ignore + expected = pd.DataFrame( + { + "0": [ + [0.75, 0.25], + [0.75, 0.25], + [-0.25, -0.75], + [-0.25, -0.75], + ], + "id": [2, 1, 3, 4], + "my_embedding": [ + [1.0, 0.0], + [0.0, 1.0], + [0.0, -1.0], + [-1.0, 0.0], + ], + "distance": [ + 0.051317, + 0.683772, + 0.051317, + 0.683772, + ], + }, + index=pd.Index([0, 0, 1, 1], dtype="Int64"), + ) + pd.testing.assert_frame_equal( + vector_search_result, expected, check_dtype=False, rtol=0.1 + ) + finally: + session.bqclient.delete_table(base_table, not_found_ok=True) diff --git a/tests/system/small/ml/test_forecasting.py b/tests/system/small/ml/test_forecasting.py index d1b6b18fbe..134f82e96e 100644 --- a/tests/system/small/ml/test_forecasting.py +++ b/tests/system/small/ml/test_forecasting.py @@ -432,8 +432,10 @@ def test_arima_plus_detect_anomalies_params( }, ) pd.testing.assert_frame_equal( - anomalies[["is_anomaly", "lower_bound", "upper_bound", "anomaly_probability"]], - expected, + anomalies[["is_anomaly", "lower_bound", "upper_bound", "anomaly_probability"]] + .sort_values("anomaly_probability") + .reset_index(drop=True), + expected.sort_values("anomaly_probability").reset_index(drop=True), rtol=0.1, check_index_type=False, check_dtype=False, @@ -449,11 +451,16 @@ def test_arima_plus_score( id_col_name, ): if id_col_name: - result = time_series_arima_plus_model_w_id.score( - new_time_series_df_w_id[["parsed_date"]], - new_time_series_df_w_id[["total_visits"]], - new_time_series_df_w_id[["id"]], - ).to_pandas() + result = ( + time_series_arima_plus_model_w_id.score( + new_time_series_df_w_id[["parsed_date"]], + new_time_series_df_w_id[["total_visits"]], + new_time_series_df_w_id[["id"]], + ) + .to_pandas() + .sort_values("id") + .reset_index(drop=True) + ) else: result = time_series_arima_plus_model.score( new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]] @@ -472,6 +479,8 @@ def test_arima_plus_score( ) expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True) expected["id"] = expected["id"].astype("string[pyarrow]") + expected = expected.sort_values("id") + expected = expected.reset_index(drop=True) else: expected = pd.DataFrame( { @@ -488,6 +497,7 @@ def test_arima_plus_score( expected, rtol=0.1, check_index_type=False, + check_dtype=False, ) @@ -542,11 +552,16 @@ def test_arima_plus_score_series( id_col_name, ): if id_col_name: - result = time_series_arima_plus_model_w_id.score( - new_time_series_df_w_id["parsed_date"], - new_time_series_df_w_id["total_visits"], - new_time_series_df_w_id["id"], - ).to_pandas() + result = ( + time_series_arima_plus_model_w_id.score( + new_time_series_df_w_id["parsed_date"], + new_time_series_df_w_id["total_visits"], + new_time_series_df_w_id["id"], + ) + .to_pandas() + .sort_values("id") + .reset_index(drop=True) + ) else: result = time_series_arima_plus_model.score( new_time_series_df["parsed_date"], new_time_series_df["total_visits"] @@ -565,6 +580,8 @@ def test_arima_plus_score_series( ) expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True) expected["id"] = expected["id"].astype("string[pyarrow]") + expected = expected.sort_values("id") + expected = 
expected.reset_index(drop=True) else: expected = pd.DataFrame( { @@ -581,6 +598,7 @@ def test_arima_plus_score_series( expected, rtol=0.1, check_index_type=False, + check_dtype=False, ) diff --git a/tests/system/small/ml/test_preprocessing.py b/tests/system/small/ml/test_preprocessing.py index 34be48be1e..65a851efc3 100644 --- a/tests/system/small/ml/test_preprocessing.py +++ b/tests/system/small/ml/test_preprocessing.py @@ -245,7 +245,7 @@ def test_max_abs_scaler_save_load(new_penguins_df, dataset_id): index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"), ) - pd.testing.assert_frame_equal(result, expected, rtol=0.1) + pd.testing.assert_frame_equal(result.sort_index(), expected.sort_index(), rtol=0.1) def test_min_max_scaler_normalized_fit_transform(new_penguins_df): diff --git a/tests/system/small/session/test_read_gbq_query.py b/tests/system/small/session/test_read_gbq_query.py new file mode 100644 index 0000000000..c1408febca --- /dev/null +++ b/tests/system/small/session/test_read_gbq_query.py @@ -0,0 +1,113 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pytest + +import bigframes +import bigframes.core.nodes as nodes + + +def test_read_gbq_query_w_allow_large_results(session: bigframes.Session): + if not hasattr(session.bqclient, "default_job_creation_mode"): + pytest.skip("Jobless query only available on newer google-cloud-bigquery.") + + query = "SELECT 1" + + # Make sure we don't get a cached table. + configuration = {"query": {"useQueryCache": False}} + + # Very small results should wrap a local node. + df_false = session.read_gbq( + query, + configuration=configuration, + allow_large_results=False, + ) + assert df_false.shape == (1, 1) + roots_false = df_false._get_block().expr.node.roots + assert any(isinstance(node, nodes.ReadLocalNode) for node in roots_false) + assert not any(isinstance(node, nodes.ReadTableNode) for node in roots_false) + + # Large results allowed should wrap a table. 
+ df_true = session.read_gbq( + query, + configuration=configuration, + allow_large_results=True, + ) + assert df_true.shape == (1, 1) + roots_true = df_true._get_block().expr.node.roots + assert any(isinstance(node, nodes.ReadTableNode) for node in roots_true) + + +def test_read_gbq_query_w_columns(session: bigframes.Session): + query = """ + SELECT 1 as int_col, + 'a' as str_col, + TIMESTAMP('2025-08-21 10:41:32.123456') as timestamp_col + """ + + result = session.read_gbq( + query, + columns=["timestamp_col", "int_col"], + ) + assert list(result.columns) == ["timestamp_col", "int_col"] + assert result.to_dict(orient="records") == [ + { + "timestamp_col": datetime.datetime( + 2025, 8, 21, 10, 41, 32, 123456, tzinfo=datetime.timezone.utc + ), + "int_col": 1, + } + ] + + +@pytest.mark.parametrize( + ("index_col", "expected_index_names"), + ( + pytest.param( + "my_custom_index", + ("my_custom_index",), + id="string", + ), + pytest.param( + ("my_custom_index",), + ("my_custom_index",), + id="iterable", + ), + pytest.param( + ("my_custom_index", "int_col"), + ("my_custom_index", "int_col"), + id="multiindex", + ), + ), +) +def test_read_gbq_query_w_index_col( + session: bigframes.Session, index_col, expected_index_names +): + query = """ + SELECT 1 as int_col, + 'a' as str_col, + 0 as my_custom_index, + TIMESTAMP('2025-08-21 10:41:32.123456') as timestamp_col + """ + + result = session.read_gbq( + query, + index_col=index_col, + ) + assert tuple(result.index.names) == expected_index_names + assert frozenset(result.columns) == frozenset( + {"int_col", "str_col", "my_custom_index", "timestamp_col"} + ) - frozenset(expected_index_names) diff --git a/tests/system/small/test_pandas_options.py b/tests/system/small/test_pandas_options.py index 1d360e0d4f..7a750ddfd3 100644 --- a/tests/system/small/test_pandas_options.py +++ b/tests/system/small/test_pandas_options.py @@ -280,6 +280,17 @@ def test_credentials_need_reauthentication( session = bpd.get_global_session() assert session.bqclient._http.credentials.valid + # We look at the thread-local session because of the + # reset_default_session_and_location fixture and that this test mutates + # state that might otherwise be used by tests running in parallel. + current_session = ( + bigframes.core.global_session._global_session_state.thread_local_session + ) + assert current_session is not None + + # Force a temp table to be created, so there is something to cleanup. + current_session._anon_dataset_manager.create_temp_table(schema=()) + with monkeypatch.context() as m: # Simulate expired credentials to trigger the credential refresh flow m.setattr( @@ -303,15 +314,6 @@ def test_credentials_need_reauthentication( with pytest.raises(google.auth.exceptions.RefreshError): bpd.read_gbq(test_query) - # Now verify that closing the session works We look at the - # thread-local session because of the - # reset_default_session_and_location fixture and that this test mutates - # state that might otherwise be used by tests running in parallel. 
- assert ( - bigframes.core.global_session._global_session_state.thread_local_session - is not None - ) - with warnings.catch_warnings(record=True) as warned: bpd.close_session() # CleanupFailedWarning: can't clean up diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index f0a6302c7b..6343f0cc53 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -619,7 +619,7 @@ def test_read_gbq_wildcard( pytest.param( {"query": {"useQueryCache": False, "maximumBytesBilled": "100"}}, marks=pytest.mark.xfail( - raises=google.api_core.exceptions.InternalServerError, + raises=google.api_core.exceptions.BadRequest, reason="Expected failure when the query exceeds the maximum bytes billed limit.", ), ), diff --git a/tests/system/small/test_unordered.py b/tests/system/small/test_unordered.py index ccb2140799..867067a161 100644 --- a/tests/system/small/test_unordered.py +++ b/tests/system/small/test_unordered.py @@ -103,7 +103,7 @@ def test_unordered_mode_read_gbq(unordered_session): } ) # Don't need ignore_order as there is only 1 row - assert_pandas_df_equal(df.to_pandas(), expected) + assert_pandas_df_equal(df.to_pandas(), expected, check_index_type=False) @pytest.mark.parametrize( diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 10fefcc457..7f6843aacf 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -143,9 +143,10 @@ def test_linear_regression_predict(mock_session, bqml_model, mock_X): model._bqml_model = bqml_model model.predict(mock_X) - mock_session.read_gbq.assert_called_once_with( + mock_session.read_gbq_query.assert_called_once_with( "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))", index_col=["index_column_id"], + allow_large_results=True, ) @@ -154,8 +155,9 @@ def test_linear_regression_score(mock_session, bqml_model, mock_X, mock_y): model._bqml_model = bqml_model model.score(mock_X, mock_y) - mock_session.read_gbq.assert_called_once_with( - "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))" + mock_session.read_gbq_query.assert_called_once_with( + "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))", + allow_large_results=True, ) @@ -167,7 +169,7 @@ def test_logistic_regression_default_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql", ) @@ -198,9 +200,10 @@ def test_logistic_regression_predict(mock_session, bqml_model, mock_X): model._bqml_model = bqml_model model.predict(mock_X) - 
mock_session.read_gbq.assert_called_once_with( + mock_session.read_gbq_query.assert_called_once_with( "SELECT * FROM ML.PREDICT(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))", index_col=["index_column_id"], + allow_large_results=True, ) @@ -209,8 +212,9 @@ def test_logistic_regression_score(mock_session, bqml_model, mock_X, mock_y): model._bqml_model = bqml_model model.score(mock_X, mock_y) - mock_session.read_gbq.assert_called_once_with( - "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))" + mock_session.read_gbq_query.assert_called_once_with( + "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_y_sql))", + allow_large_results=True, ) @@ -243,9 +247,10 @@ def test_decomposition_mf_predict(mock_session, bqml_model, mock_X): model._bqml_model = bqml_model model.predict(mock_X) - mock_session.read_gbq.assert_called_once_with( + mock_session.read_gbq_query.assert_called_once_with( "SELECT * FROM ML.RECOMMEND(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql))", index_col=["index_column_id"], + allow_large_results=True, ) @@ -260,8 +265,9 @@ def test_decomposition_mf_score(mock_session, bqml_model): ) model._bqml_model = bqml_model model.score() - mock_session.read_gbq.assert_called_once_with( - "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)" + mock_session.read_gbq_query.assert_called_once_with( + "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)", + allow_large_results=True, ) @@ -276,6 +282,7 @@ def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X): ) model._bqml_model = bqml_model model.score(mock_X) - mock_session.read_gbq.assert_called_once_with( - "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))" + mock_session.read_gbq_query.assert_called_once_with( + "SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))", + allow_large_results=True, ) diff --git a/tests/unit/session/test_read_gbq_query.py b/tests/unit/session/test_read_gbq_query.py index afd9922426..1f9d2fb945 100644 --- a/tests/unit/session/test_read_gbq_query.py +++ b/tests/unit/session/test_read_gbq_query.py @@ -25,7 +25,7 @@ def test_read_gbq_query_sets_destination_table(): # Use partial ordering mode to skip column uniqueness checks. session = mocks.create_bigquery_session(ordering_mode="partial") - _ = session.read_gbq_query("SELECT 'my-test-query';") + _ = session.read_gbq_query("SELECT 'my-test-query';", allow_large_results=True) queries = session._queries # type: ignore configs = session._job_configs # type: ignore diff --git a/third_party/bigframes_vendored/pandas/io/gbq.py b/third_party/bigframes_vendored/pandas/io/gbq.py index 3dae2b6bbe..0fdca4dde1 100644 --- a/third_party/bigframes_vendored/pandas/io/gbq.py +++ b/third_party/bigframes_vendored/pandas/io/gbq.py @@ -25,6 +25,7 @@ def read_gbq( filters: FiltersType = (), use_cache: Optional[bool] = None, col_order: Iterable[str] = (), + allow_large_results: Optional[bool] = None, ): """Loads a DataFrame from BigQuery. @@ -156,6 +157,11 @@ def read_gbq( `configuration` to avoid conflicts. col_order (Iterable[str]): Alias for columns, retained for backwards compatibility. + allow_large_results (bool, optional): + Whether to allow large query results. If ``True``, the query + results can be larger than the maximum response size. 
This + option is only applicable when ``query_or_table`` is a query. + Defaults to ``bpd.options.compute.allow_large_results``. Raises: bigframes.exceptions.DefaultIndexWarning: From cbbbce335544004bd7bc7acee99c82ff23c8eaee Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Tue, 2 Sep 2025 13:02:55 -0700 Subject: [PATCH 25/28] refactor: Unify bigquery execution paths (#2007) --- bigframes/core/blocks.py | 66 +++- bigframes/core/indexes/base.py | 11 +- bigframes/dataframe.py | 56 +-- bigframes/session/bq_caching_executor.py | 446 ++++++++++------------- bigframes/session/execution_spec.py | 53 +++ bigframes/session/executor.py | 47 +-- bigframes/testing/compiler_session.py | 7 + bigframes/testing/polars_session.py | 37 +- tests/system/small/test_session.py | 11 +- 9 files changed, 375 insertions(+), 359 deletions(-) create mode 100644 bigframes/session/execution_spec.py diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index f7d456bf9d..07d7e4c45b 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -69,7 +69,7 @@ import bigframes.exceptions as bfe import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops -from bigframes.session import dry_runs +from bigframes.session import dry_runs, execution_spec from bigframes.session import executor as executors # Type constraint for wherever column labels are used @@ -257,7 +257,10 @@ def shape(self) -> typing.Tuple[int, int]: except Exception: pass - row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar() + row_count = self.session._executor.execute( + self.expr.row_count(), + execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False), + ).to_py_scalar() return (row_count, len(self.value_columns)) @property @@ -557,8 +560,17 @@ def to_arrow( allow_large_results: Optional[bool] = None, ) -> Tuple[pa.Table, Optional[bigquery.QueryJob]]: """Run query and download results as a pyarrow Table.""" + under_10gb = ( + (not allow_large_results) + if (allow_large_results is not None) + else not bigframes.options._allow_large_results + ) execute_result = self.session._executor.execute( - self.expr, ordered=ordered, use_explicit_destination=allow_large_results + self.expr, + execution_spec.ExecutionSpec( + promise_under_10gb=under_10gb, + ordered=ordered, + ), ) pa_table = execute_result.to_arrow_table() @@ -647,8 +659,15 @@ def try_peek( self, n: int = 20, force: bool = False, allow_large_results=None ) -> typing.Optional[pd.DataFrame]: if force or self.expr.supports_fast_peek: - result = self.session._executor.peek( - self.expr, n, use_explicit_destination=allow_large_results + # really, we should just block insane peek values and always assume <10gb + under_10gb = ( + (not allow_large_results) + if (allow_large_results is not None) + else not bigframes.options._allow_large_results + ) + result = self.session._executor.execute( + self.expr, + execution_spec.ExecutionSpec(promise_under_10gb=under_10gb, peek=n), ) df = result.to_pandas() return self._copy_index_to_pandas(df) @@ -665,10 +684,18 @@ def to_pandas_batches( page_size and max_results determine the size and number of batches, see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJob#google_cloud_bigquery_job_QueryJob_result""" + + under_10gb = ( + (not allow_large_results) + if (allow_large_results is not None) + else not bigframes.options._allow_large_results + ) execute_result = self.session._executor.execute( self.expr, - ordered=True, - 
use_explicit_destination=allow_large_results, + execution_spec.ExecutionSpec( + promise_under_10gb=under_10gb, + ordered=True, + ), ) # To reduce the number of edge cases to consider when working with the @@ -714,10 +741,17 @@ def _materialize_local( ) -> Tuple[pd.DataFrame, Optional[bigquery.QueryJob]]: """Run query and download results as a pandas DataFrame. Return the total number of results as well.""" # TODO(swast): Allow for dry run and timeout. + under_10gb = ( + (not materialize_options.allow_large_results) + if (materialize_options.allow_large_results is not None) + else (not bigframes.options._allow_large_results) + ) execute_result = self.session._executor.execute( self.expr, - ordered=materialize_options.ordered, - use_explicit_destination=materialize_options.allow_large_results, + execution_spec.ExecutionSpec( + promise_under_10gb=under_10gb, + ordered=materialize_options.ordered, + ), ) sample_config = materialize_options.downsampling if execute_result.total_bytes is not None: @@ -1598,9 +1632,19 @@ def retrieve_repr_request_results( config=executors.CacheConfig(optimize_for="head", if_cached="reuse-strict"), ) head_result = self.session._executor.execute( - self.expr.slice(start=None, stop=max_results, step=None) + self.expr.slice(start=None, stop=max_results, step=None), + execution_spec.ExecutionSpec( + promise_under_10gb=True, + ordered=True, + ), ) - row_count = self.session._executor.execute(self.expr.row_count()).to_py_scalar() + row_count = self.session._executor.execute( + self.expr.row_count(), + execution_spec.ExecutionSpec( + promise_under_10gb=True, + ordered=False, + ), + ).to_py_scalar() head_df = head_result.to_pandas() return self._copy_index_to_pandas(head_df), row_count, head_result.query_job diff --git a/bigframes/core/indexes/base.py b/bigframes/core/indexes/base.py index e022b3f151..f8ec38621d 100644 --- a/bigframes/core/indexes/base.py +++ b/bigframes/core/indexes/base.py @@ -38,6 +38,7 @@ import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops import bigframes.series +import bigframes.session.execution_spec as ex_spec if typing.TYPE_CHECKING: import bigframes.dataframe @@ -283,8 +284,9 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: # Check if key exists at all by counting count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id)) count_result = filtered_block._expr.aggregate([(count_agg, "count")]) + count_scalar = self._block.session._executor.execute( - count_result + count_result, ex_spec.ExecutionSpec(promise_under_10gb=True) ).to_py_scalar() if count_scalar == 0: @@ -295,7 +297,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)) position_result = filtered_block._expr.aggregate([(min_agg, "position")]) position_scalar = self._block.session._executor.execute( - position_result + position_result, ex_spec.ExecutionSpec(promise_under_10gb=True) ).to_py_scalar() return int(position_scalar) @@ -326,7 +328,10 @@ def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice: combined_result = filtered_block._expr.aggregate(min_max_aggs) # Execute query and extract positions - result_df = self._block.session._executor.execute(combined_result).to_pandas() + result_df = self._block.session._executor.execute( + combined_result, + execution_spec=ex_spec.ExecutionSpec(promise_under_10gb=True), + ).to_pandas() min_pos = int(result_df["min_pos"].iloc[0]) max_pos = 
int(result_df["max_pos"].iloc[0]) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 75be1c256e..f9de117b29 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -86,6 +86,7 @@ import bigframes.operations.structs import bigframes.series import bigframes.session._io.bigquery +import bigframes.session.execution_spec as ex_spec if typing.TYPE_CHECKING: from _typeshed import SupportsRichComparison @@ -4268,17 +4269,19 @@ def to_csv( index=index and self._has_index, ordering_id=bigframes.session._io.bigquery.IO_ORDERING_ID, ) - options = { + options: dict[str, Union[bool, str]] = { "field_delimiter": sep, "header": header, } - query_job = self._session._executor.export_gcs( + result = self._session._executor.execute( export_array.rename_columns(id_overrides), - path_or_buf, - format="csv", - export_options=options, + ex_spec.ExecutionSpec( + ex_spec.GcsOutputSpec( + uri=path_or_buf, format="csv", export_options=tuple(options.items()) + ) + ), ) - self._set_internal_query_job(query_job) + self._set_internal_query_job(result.query_job) return None def to_json( @@ -4321,13 +4324,13 @@ def to_json( index=index and self._has_index, ordering_id=bigframes.session._io.bigquery.IO_ORDERING_ID, ) - query_job = self._session._executor.export_gcs( + result = self._session._executor.execute( export_array.rename_columns(id_overrides), - path_or_buf, - format="json", - export_options={}, + ex_spec.ExecutionSpec( + ex_spec.GcsOutputSpec(uri=path_or_buf, format="json", export_options=()) + ), ) - self._set_internal_query_job(query_job) + self._set_internal_query_job(result.query_job) return None def to_gbq( @@ -4400,16 +4403,21 @@ def to_gbq( ) ) - query_job = self._session._executor.export_gbq( + result = self._session._executor.execute( export_array.rename_columns(id_overrides), - destination=destination, - cluster_cols=clustering_fields, - if_exists=if_exists, + ex_spec.ExecutionSpec( + ex_spec.TableOutputSpec( + destination, + cluster_cols=tuple(clustering_fields), + if_exists=if_exists, + ) + ), ) - self._set_internal_query_job(query_job) + assert result.query_job is not None + self._set_internal_query_job(result.query_job) # The query job should have finished, so there should be always be a result table. 
- result_table = query_job.destination + result_table = result.query_job.destination assert result_table is not None if temp_table_ref: @@ -4477,13 +4485,17 @@ def to_parquet( index=index and self._has_index, ordering_id=bigframes.session._io.bigquery.IO_ORDERING_ID, ) - query_job = self._session._executor.export_gcs( + result = self._session._executor.execute( export_array.rename_columns(id_overrides), - path, - format="parquet", - export_options=export_options, + ex_spec.ExecutionSpec( + ex_spec.GcsOutputSpec( + uri=path, + format="parquet", + export_options=tuple(export_options.items()), + ) + ), ) - self._set_internal_query_job(query_job) + self._set_internal_query_job(result.query_job) return None def to_dict( diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index a970e75a0f..b428cd646c 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -14,11 +14,9 @@ from __future__ import annotations -import dataclasses import math -import os import threading -from typing import cast, Literal, Mapping, Optional, Sequence, Tuple, Union +from typing import Literal, Mapping, Optional, Sequence, Tuple import warnings import weakref @@ -35,12 +33,12 @@ from bigframes.core import compile, local_data, rewrite import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir import bigframes.core.guid +import bigframes.core.identifiers import bigframes.core.nodes as nodes import bigframes.core.ordering as order import bigframes.core.schema as schemata import bigframes.core.tree_properties as tree_properties import bigframes.dtypes -import bigframes.features from bigframes.session import ( executor, loader, @@ -49,6 +47,7 @@ semi_executor, ) import bigframes.session._io.bigquery as bq_io +import bigframes.session.execution_spec as ex_spec import bigframes.session.metrics import bigframes.session.planner import bigframes.session.temporary_storage @@ -61,21 +60,6 @@ MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G -@dataclasses.dataclass -class OutputSpec: - require_bq_table: bool - cluster_cols: tuple[str, ...] 
- - def with_require_table(self, value: bool) -> OutputSpec: - return dataclasses.replace(self, require_bq_table=value) - - -def _get_default_output_spec() -> OutputSpec: - return OutputSpec( - require_bq_table=bigframes.options._allow_large_results, cluster_cols=() - ) - - SourceIdMapping = Mapping[str, str] @@ -189,7 +173,11 @@ def to_sql( ) -> str: if offset_column: array_value, _ = array_value.promote_offsets() - node = self.logical_plan(array_value.node) if enable_cache else array_value.node + node = ( + self.prepare_plan(array_value.node, target="simplify") + if enable_cache + else array_value.node + ) node = self._substitute_large_local_sources(node) compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=ordered)) return compiled.sql @@ -197,86 +185,113 @@ def to_sql( def execute( self, array_value: bigframes.core.ArrayValue, - *, - ordered: bool = True, - use_explicit_destination: Optional[bool] = None, + execution_spec: ex_spec.ExecutionSpec, ) -> executor.ExecuteResult: - if bigframes.options.compute.enable_multi_query_execution: - self._simplify_with_caching(array_value) - - output_spec = _get_default_output_spec() - if use_explicit_destination is not None: - output_spec = output_spec.with_require_table(use_explicit_destination) - - plan = self.logical_plan(array_value.node) - return self._execute_plan( - plan, - ordered=ordered, - output_spec=output_spec, - ) + # TODO: Support export jobs in combination with semi executors + if execution_spec.destination_spec is None: + plan = self.prepare_plan(array_value.node, target="simplify") + for exec in self._semi_executors: + maybe_result = exec.execute( + plan, ordered=execution_spec.ordered, peek=execution_spec.peek + ) + if maybe_result: + return maybe_result - def peek( - self, - array_value: bigframes.core.ArrayValue, - n_rows: int, - use_explicit_destination: Optional[bool] = None, - ) -> executor.ExecuteResult: - """ - A 'peek' efficiently accesses a small number of rows in the dataframe. 
- """ - plan = self.logical_plan(array_value.node) - if not tree_properties.can_fast_peek(plan): - msg = bfe.format_message("Peeking this value cannot be done efficiently.") - warnings.warn(msg) + if isinstance(execution_spec.destination_spec, ex_spec.TableOutputSpec): + if execution_spec.peek or execution_spec.ordered: + raise NotImplementedError( + "Ordering and peeking not supported for gbq export" + ) + # separate path for export_gbq, as it has all sorts of annoying logic, such as possibly running as dml + return self._export_gbq(array_value, execution_spec.destination_spec) + + result = self._execute_plan_gbq( + array_value.node, + ordered=execution_spec.ordered, + peek=execution_spec.peek, + cache_spec=execution_spec.destination_spec + if isinstance(execution_spec.destination_spec, ex_spec.CacheSpec) + else None, + must_create_table=not execution_spec.promise_under_10gb, + ) + # post steps: export + if isinstance(execution_spec.destination_spec, ex_spec.GcsOutputSpec): + self._export_result_gcs(result, execution_spec.destination_spec) - output_spec = _get_default_output_spec() - if use_explicit_destination is not None: - output_spec = output_spec.with_require_table(use_explicit_destination) + return result - return self._execute_plan( - plan, ordered=False, output_spec=output_spec, peek=n_rows + def _export_result_gcs( + self, result: executor.ExecuteResult, gcs_export_spec: ex_spec.GcsOutputSpec + ): + query_job = result.query_job + assert query_job is not None + result_table = query_job.destination + assert result_table is not None + export_data_statement = bq_io.create_export_data_statement( + f"{result_table.project}.{result_table.dataset_id}.{result_table.table_id}", + uri=gcs_export_spec.uri, + format=gcs_export_spec.format, + export_options=dict(gcs_export_spec.export_options), + ) + bq_io.start_query_with_client( + self.bqclient, + export_data_statement, + job_config=bigquery.QueryJobConfig(), + metrics=self.metrics, + project=None, + location=None, + timeout=None, + query_with_job=True, ) - def export_gbq( - self, - array_value: bigframes.core.ArrayValue, - destination: bigquery.TableReference, - if_exists: Literal["fail", "replace", "append"] = "fail", - cluster_cols: Sequence[str] = [], - ): + def _maybe_find_existing_table( + self, spec: ex_spec.TableOutputSpec + ) -> Optional[bigquery.Table]: + # validate destination table + try: + table = self.bqclient.get_table(spec.table) + if spec.if_exists == "fail": + raise ValueError(f"Table already exists: {spec.table.__str__()}") + + if len(spec.cluster_cols) != 0: + if (table.clustering_fields is None) or ( + tuple(table.clustering_fields) != spec.cluster_cols + ): + raise ValueError( + "Table clustering fields cannot be changed after the table has " + f"been created. Requested clustering fields: {spec.cluster_cols}, existing clustering fields: {table.clustering_fields}" + ) + return table + except google.api_core.exceptions.NotFound: + return None + + def _export_gbq( + self, array_value: bigframes.core.ArrayValue, spec: ex_spec.TableOutputSpec + ) -> executor.ExecuteResult: """ Export the ArrayValue to an existing BigQuery table. 
""" - if bigframes.options.compute.enable_multi_query_execution: - self._simplify_with_caching(array_value) + plan = self.prepare_plan(array_value.node, target="bq_execution") - table_exists = True - try: - table = self.bqclient.get_table(destination) - if if_exists == "fail": - raise ValueError(f"Table already exists: {destination.__str__()}") - except google.api_core.exceptions.NotFound: - table_exists = False + # validate destination table + existing_table = self._maybe_find_existing_table(spec) - if len(cluster_cols) != 0: - if table_exists and table.clustering_fields != cluster_cols: - raise ValueError( - "Table clustering fields cannot be changed after the table has " - f"been created. Existing clustering fields: {table.clustering_fields}" - ) + compiled = compile.compile_sql(compile.CompileRequest(plan, sort_rows=False)) + sql = compiled.sql - sql = self.to_sql(array_value, ordered=False) - if table_exists and _if_schema_match(table.schema, array_value.schema): + if (existing_table is not None) and _if_schema_match( + existing_table.schema, array_value.schema + ): # b/409086472: Uses DML for table appends and replacements to avoid # BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits: # https://cloud.google.com/bigquery/quotas#standard_tables job_config = bigquery.QueryJobConfig() ir = sqlglot_ir.SQLGlotIR.from_query_string(sql) - if if_exists == "append": - sql = ir.insert(destination) + if spec.if_exists == "append": + sql = ir.insert(spec.table) else: # for "replace" - assert if_exists == "replace" - sql = ir.replace(destination) + assert spec.if_exists == "replace" + sql = ir.replace(spec.table) else: dispositions = { "fail": bigquery.WriteDisposition.WRITE_EMPTY, @@ -284,14 +299,14 @@ def export_gbq( "append": bigquery.WriteDisposition.WRITE_APPEND, } job_config = bigquery.QueryJobConfig( - write_disposition=dispositions[if_exists], - destination=destination, - clustering_fields=cluster_cols if cluster_cols else None, + write_disposition=dispositions[spec.if_exists], + destination=spec.table, + clustering_fields=spec.cluster_cols if spec.cluster_cols else None, ) # TODO(swast): plumb through the api_name of the user-facing api that # caused this query. - _, query_job = self._run_execute_query( + row_iter, query_job = self._run_execute_query( sql=sql, job_config=job_config, ) @@ -300,48 +315,16 @@ def export_gbq( t == bigframes.dtypes.TIMEDELTA_DTYPE for t in array_value.schema.dtypes ) - if if_exists != "append" and has_timedelta_col: + if spec.if_exists != "append" and has_timedelta_col: # Only update schema if this is not modifying an existing table, and the # new table contains timedelta columns. 
-                table = self.bqclient.get_table(destination)
+                table = self.bqclient.get_table(spec.table)
                 table.schema = array_value.schema.to_bigquery()
                 self.bqclient.update_table(table, ["schema"])

-        return query_job
-
-    def export_gcs(
-        self,
-        array_value: bigframes.core.ArrayValue,
-        uri: str,
-        format: Literal["json", "csv", "parquet"],
-        export_options: Mapping[str, Union[bool, str]],
-    ):
-        query_job = self.execute(
-            array_value,
-            ordered=False,
-            use_explicit_destination=True,
-        ).query_job
-        assert query_job is not None
-        result_table = query_job.destination
-        assert result_table is not None
-        export_data_statement = bq_io.create_export_data_statement(
-            f"{result_table.project}.{result_table.dataset_id}.{result_table.table_id}",
-            uri=uri,
-            format=format,
-            export_options=dict(export_options),
-        )
-
-        bq_io.start_query_with_client(
-            self.bqclient,
-            export_data_statement,
-            job_config=bigquery.QueryJobConfig(),
-            metrics=self.metrics,
-            project=None,
-            location=None,
-            timeout=None,
-            query_with_job=True,
+        return executor.ExecuteResult(
+            row_iter.to_arrow_iterable(), array_value.schema, query_job
         )
-        return query_job

     def dry_run(
         self, array_value: bigframes.core.ArrayValue, ordered: bool = True
@@ -446,59 +429,56 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
         # Once rewriting is available, will want to rewrite before
         # evaluating execution cost.
         return tree_properties.is_trivially_executable(
-            self.logical_plan(array_value.node)
+            self.prepare_plan(array_value.node)
         )

-    def logical_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode:
+    def prepare_plan(
+        self,
+        plan: nodes.BigFrameNode,
+        target: Literal["simplify", "bq_execution"] = "simplify",
+    ) -> nodes.BigFrameNode:
         """
-        Apply universal logical simplifications that are helpful regardless of engine.
+        Prepare the plan by simplifying it with caches and removing unused operators. Has modes for different contexts.
+
+        "simplify" removes unused operations and substitutes subtrees with their previously cached equivalents.
+        "bq_execution" is the heaviest option, preparing the plan for bq execution by also caching subtrees and uploading large local sources.
         """
-        plan = self.replace_cached_subtrees(root)
+        # TODO: We should model plan decomposition and data uploading as work steps rather than as plan preparation.
+ if ( + target == "bq_execution" + and bigframes.options.compute.enable_multi_query_execution + ): + self._simplify_with_caching(plan) + + plan = self.replace_cached_subtrees(plan) plan = rewrite.column_pruning(plan) plan = plan.top_down(rewrite.fold_row_counts) + + if target == "bq_execution": + plan = self._substitute_large_local_sources(plan) + return plan def _cache_with_cluster_cols( self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str] ): """Executes the query and uses the resulting table to rewrite future executions.""" - plan = self.logical_plan(array_value.node) - plan = self._substitute_large_local_sources(plan) - compiled = compile.compile_sql( - compile.CompileRequest( - plan, sort_rows=False, materialize_all_order_keys=True - ) - ) - tmp_table_ref, num_rows = self._sql_as_cached_temp_table( - compiled.sql, - compiled.sql_schema, - cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols), + execution_spec = ex_spec.ExecutionSpec( + destination_spec=ex_spec.CacheSpec(cluster_cols=tuple(cluster_cols)) ) - tmp_table = self.bqclient.get_table(tmp_table_ref) - assert compiled.row_order is not None - self.cache.cache_results_table( - array_value.node, tmp_table, compiled.row_order, num_rows=num_rows + self.execute( + array_value, + execution_spec=execution_spec, ) def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): """Executes the query and uses the resulting table to rewrite future executions.""" - offset_column = bigframes.core.guid.generate_guid("bigframes_offsets") - w_offsets, offset_column = array_value.promote_offsets() - compiled = compile.compile_sql( - compile.CompileRequest( - self.logical_plan(self._substitute_large_local_sources(w_offsets.node)), - sort_rows=False, - ) + execution_spec = ex_spec.ExecutionSpec( + destination_spec=ex_spec.CacheSpec(cluster_cols=tuple()) ) - tmp_table_ref, num_rows = self._sql_as_cached_temp_table( - compiled.sql, - compiled.sql_schema, - cluster_cols=[offset_column], - ) - tmp_table = self.bqclient.get_table(tmp_table_ref) - assert compiled.row_order is not None - self.cache.cache_results_table( - array_value.node, tmp_table, compiled.row_order, num_rows=num_rows + self.execute( + array_value, + execution_spec=execution_spec, ) def _cache_with_session_awareness( @@ -520,17 +500,17 @@ def _cache_with_session_awareness( else: self._cache_with_cluster_cols(bigframes.core.ArrayValue(target), []) - def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue): + def _simplify_with_caching(self, plan: nodes.BigFrameNode): """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces.""" # Apply existing caching first for _ in range(MAX_SUBTREE_FACTORINGS): if ( - self.logical_plan(array_value.node).planning_complexity + self.prepare_plan(plan, "simplify").planning_complexity < QUERY_COMPLEXITY_LIMIT ): return - did_cache = self._cache_most_complex_subtree(array_value.node) + did_cache = self._cache_most_complex_subtree(plan) if not did_cache: return @@ -552,52 +532,6 @@ def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool: self._cache_with_cluster_cols(bigframes.core.ArrayValue(selection), []) return True - def _sql_as_cached_temp_table( - self, - sql: str, - schema: Sequence[bigquery.SchemaField], - cluster_cols: Sequence[str], - ) -> tuple[bigquery.TableReference, Optional[int]]: - assert len(cluster_cols) <= _MAX_CLUSTER_COLUMNS - temp_table = self.storage_manager.create_temp_table(schema, cluster_cols) - - # TODO: 
Get default job config settings - job_config = cast( - bigquery.QueryJobConfig, - bigquery.QueryJobConfig.from_api_repr({}), - ) - job_config.destination = temp_table - _, query_job = self._run_execute_query( - sql, - job_config=job_config, - ) - assert query_job is not None - iter = query_job.result() - return query_job.destination, iter.total_rows - - def _validate_result_schema( - self, - array_value: bigframes.core.ArrayValue, - bq_schema: list[bigquery.SchemaField], - ): - actual_schema = _sanitize(tuple(bq_schema)) - ibis_schema = compile.test_only_ibis_inferred_schema( - self.logical_plan(array_value.node) - ).to_bigquery() - internal_schema = _sanitize(array_value.schema.to_bigquery()) - if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable: - return - - if internal_schema != actual_schema: - raise ValueError( - f"This error should only occur while testing. BigFrames internal schema: {internal_schema} does not match actual schema: {actual_schema}" - ) - - if ibis_schema != actual_schema: - raise ValueError( - f"This error should only occur while testing. Ibis schema: {ibis_schema} does not match actual schema: {actual_schema}" - ) - def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode): """ Replace large local sources with the uploaded version of those datasources. @@ -646,52 +580,80 @@ def _upload_local_data(self, local_table: local_data.ManagedArrowTable): ) self.cache.cache_remote_replacement(local_table, uploaded) - def _execute_plan( + def _execute_plan_gbq( self, plan: nodes.BigFrameNode, ordered: bool, - output_spec: OutputSpec, peek: Optional[int] = None, + cache_spec: Optional[ex_spec.CacheSpec] = None, + must_create_table: bool = True, ) -> executor.ExecuteResult: """Just execute whatever plan as is, without further caching or decomposition.""" - # First try to execute fast-paths - if not output_spec.require_bq_table: - for exec in self._semi_executors: - maybe_result = exec.execute(plan, ordered=ordered, peek=peek) - if maybe_result: - return maybe_result + # TODO(swast): plumb through the api_name of the user-facing api that + # caused this query. + + og_plan = plan + og_schema = plan.schema + + plan = self.prepare_plan(plan, target="bq_execution") + create_table = must_create_table + cluster_cols: Sequence[str] = [] + if cache_spec is not None: + if peek is not None: + raise ValueError("peek is not compatible with caching.") + + create_table = True + if not cache_spec.cluster_cols: + assert len(cache_spec.cluster_cols) <= _MAX_CLUSTER_COLUMNS + offsets_id = bigframes.core.identifiers.ColumnId( + bigframes.core.guid.generate_guid() + ) + plan = nodes.PromoteOffsetsNode(plan, offsets_id) + cluster_cols = [offsets_id.sql] + else: + cluster_cols = cache_spec.cluster_cols - # Use explicit destination to avoid 10GB limit of temporary table - destination_table = ( - self.storage_manager.create_temp_table( - plan.schema.to_bigquery(), cluster_cols=output_spec.cluster_cols + compiled = compile.compile_sql( + compile.CompileRequest( + plan, + sort_rows=ordered, + peek_count=peek, + materialize_all_order_keys=(cache_spec is not None), ) - if output_spec.require_bq_table - else None ) + # might have more columns than og schema, for hidden ordering columns + compiled_schema = compiled.sql_schema + + destination_table: Optional[bigquery.TableReference] = None - # TODO(swast): plumb through the api_name of the user-facing api that - # caused this query. 
job_config = bigquery.QueryJobConfig() - # Use explicit destination to avoid 10GB limit of temporary table - if destination_table is not None: + if create_table: + destination_table = self.storage_manager.create_temp_table( + compiled_schema, cluster_cols + ) job_config.destination = destination_table - plan = self._substitute_large_local_sources(plan) - compiled = compile.compile_sql( - compile.CompileRequest(plan, sort_rows=ordered, peek_count=peek) - ) iterator, query_job = self._run_execute_query( sql=compiled.sql, job_config=job_config, query_with_job=(destination_table is not None), ) - if query_job: - size_bytes = self.bqclient.get_table(query_job.destination).num_bytes + table_info: Optional[bigquery.Table] = None + if query_job and query_job.destination: + table_info = self.bqclient.get_table(query_job.destination) + size_bytes = table_info.num_bytes else: size_bytes = None + # we could actually cache even when caching is not explicitly requested, but being conservative for now + if cache_spec is not None: + assert table_info is not None + assert compiled.row_order is not None + self.cache.cache_results_table( + og_plan, table_info, compiled.row_order, num_rows=table_info.num_rows + ) + if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES: msg = bfe.format_message( "The query result size has exceeded 10 GB. In BigFrames 2.0 and " @@ -700,18 +662,12 @@ def _execute_plan( "`bigframes.options.compute.allow_large_results=True`." ) warnings.warn(msg, FutureWarning) - # Runs strict validations to ensure internal type predictions and ibis are completely in sync - # Do not execute these validations outside of testing suite. - if "PYTEST_CURRENT_TEST" in os.environ: - self._validate_result_schema( - bigframes.core.ArrayValue(plan), iterator.schema - ) return executor.ExecuteResult( _arrow_batches=iterator.to_arrow_iterable( bqstorage_client=self.bqstoragereadclient ), - schema=plan.schema, + schema=og_schema, query_job=query_job, total_bytes=size_bytes, total_rows=iterator.total_rows, @@ -731,19 +687,3 @@ def _if_schema_match( ): return False return True - - -def _sanitize( - schema: Tuple[bigquery.SchemaField, ...] -) -> Tuple[bigquery.SchemaField, ...]: - # Schema inferred from SQL strings and Ibis expressions contain only names, types and modes, - # so we disregard other fields (e.g timedelta description for timedelta columns) for validations. - return tuple( - bigquery.SchemaField( - f.name, - f.field_type, - f.mode, # type:ignore - fields=_sanitize(f.fields), - ) - for f in schema - ) diff --git a/bigframes/session/execution_spec.py b/bigframes/session/execution_spec.py new file mode 100644 index 0000000000..c9431dbd11 --- /dev/null +++ b/bigframes/session/execution_spec.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import dataclasses
+from typing import Literal, Optional, Union
+
+from google.cloud import bigquery
+
+
+@dataclasses.dataclass(frozen=True)
+class ExecutionSpec:
+    destination_spec: Union[TableOutputSpec, GcsOutputSpec, CacheSpec, None] = None
+    peek: Optional[int] = None
+    ordered: bool = (
+        False  # for bq execution, ordered=True must be used together with promise_under_10gb=True
+    )
+    # An optimization flag for gbq execution. It doesn't change semantics, but if the promise is falsely made, errors may occur.
+    promise_under_10gb: bool = False
+
+
+# This one is temporary. In the future, caching will not be done through immediate execution; instead, nodes will be labeled
+# for caching and only cached when a super-tree is executed.
+@dataclasses.dataclass(frozen=True)
+class CacheSpec:
+    cluster_cols: tuple[str, ...]
+
+
+@dataclasses.dataclass(frozen=True)
+class TableOutputSpec:
+    table: bigquery.TableReference
+    cluster_cols: tuple[str, ...]
+    if_exists: Literal["fail", "replace", "append"] = "fail"
+
+
+@dataclasses.dataclass(frozen=True)
+class GcsOutputSpec:
+    uri: str
+    format: Literal["json", "csv", "parquet"]
+    # sequence of (option, value) pairs
+    export_options: tuple[tuple[str, Union[bool, str]], ...]
diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py
index cc8f086f9f..748b10647a 100644
--- a/bigframes/session/executor.py
+++ b/bigframes/session/executor.py
@@ -18,7 +18,7 @@
 import dataclasses
 import functools
 import itertools
-from typing import Iterator, Literal, Mapping, Optional, Sequence, Union
+from typing import Iterator, Literal, Optional, Union

 from google.cloud import bigquery
 import pandas as pd
@@ -29,6 +29,7 @@
 from bigframes.core import pyarrow_utils
 import bigframes.core.schema
 import bigframes.session._io.pandas as io_pandas
+import bigframes.session.execution_spec as ex_spec

 _ROW_LIMIT_EXCEEDED_TEMPLATE = (
     "Execution has downloaded {result_rows} rows so far, which exceeds the "
@@ -147,41 +148,16 @@ def to_sql(
         """
         raise NotImplementedError("to_sql not implemented for this executor")

+    @abc.abstractmethod
     def execute(
         self,
         array_value: bigframes.core.ArrayValue,
-        *,
-        ordered: bool = True,
-        use_explicit_destination: Optional[bool] = False,
+        execution_spec: ex_spec.ExecutionSpec,
     ) -> ExecuteResult:
         """
-        Execute the ArrayValue, storing the result to a temporary session-owned table.
-        """
-        raise NotImplementedError("execute not implemented for this executor")
-
-    def export_gbq(
-        self,
-        array_value: bigframes.core.ArrayValue,
-        destination: bigquery.TableReference,
-        if_exists: Literal["fail", "replace", "append"] = "fail",
-        cluster_cols: Sequence[str] = [],
-    ) -> bigquery.QueryJob:
-        """
-        Export the ArrayValue to an existing BigQuery table.
+        Execute the ArrayValue.
         """
-        raise NotImplementedError("export_gbq not implemented for this executor")
-
-    def export_gcs(
-        self,
-        array_value: bigframes.core.ArrayValue,
-        uri: str,
-        format: Literal["json", "csv", "parquet"],
-        export_options: Mapping[str, Union[bool, str]],
-    ) -> bigquery.QueryJob:
-        """
-        Export the ArrayValue to gcs.
-        """
-        raise NotImplementedError("export_gcs not implemented for this executor")
+        ...
def dry_run( self, array_value: bigframes.core.ArrayValue, ordered: bool = True @@ -193,17 +169,6 @@ def dry_run( """ raise NotImplementedError("dry_run not implemented for this executor") - def peek( - self, - array_value: bigframes.core.ArrayValue, - n_rows: int, - use_explicit_destination: Optional[bool] = False, - ) -> ExecuteResult: - """ - A 'peek' efficiently accesses a small number of rows in the dataframe. - """ - raise NotImplementedError("peek not implemented for this executor") - def cached( self, array_value: bigframes.core.ArrayValue, diff --git a/bigframes/testing/compiler_session.py b/bigframes/testing/compiler_session.py index 35114d95d0..289b2600fd 100644 --- a/bigframes/testing/compiler_session.py +++ b/bigframes/testing/compiler_session.py @@ -41,3 +41,10 @@ def to_sql( return self.compiler.SQLGlotCompiler().compile( array_value.node, ordered=ordered ) + + def execute( + self, + array_value, + execution_spec, + ): + raise NotImplementedError("SQLCompilerExecutor.execute not implemented") diff --git a/bigframes/testing/polars_session.py b/bigframes/testing/polars_session.py index 3710c40eae..29eae20b7a 100644 --- a/bigframes/testing/polars_session.py +++ b/bigframes/testing/polars_session.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import Optional, Union +from typing import Union import weakref import pandas @@ -23,48 +23,31 @@ import bigframes.core.blocks import bigframes.core.compile.polars import bigframes.dataframe +import bigframes.session.execution_spec import bigframes.session.executor import bigframes.session.metrics -# Does not support to_sql, export_gbq, export_gcs, dry_run, peek, head, get_row_count, cached +# Does not support to_sql, dry_run, peek, cached @dataclasses.dataclass class TestExecutor(bigframes.session.executor.Executor): compiler = bigframes.core.compile.polars.PolarsCompiler() - def peek( - self, - array_value: bigframes.core.ArrayValue, - n_rows: int, - use_explicit_destination: Optional[bool] = False, - ): - """ - A 'peek' efficiently accesses a small number of rows in the dataframe. - """ - lazy_frame: polars.LazyFrame = self.compiler.compile(array_value.node) - pa_table = lazy_frame.collect().limit(n_rows).to_arrow() - # Currently, pyarrow types might not quite be exactly the ones in the bigframes schema. - # Nullability may be different, and might use large versions of list, string datatypes. - return bigframes.session.executor.ExecuteResult( - _arrow_batches=pa_table.to_batches(), - schema=array_value.schema, - total_bytes=pa_table.nbytes, - total_rows=pa_table.num_rows, - ) - def execute( self, array_value: bigframes.core.ArrayValue, - *, - ordered: bool = True, - use_explicit_destination: Optional[bool] = False, - page_size: Optional[int] = None, - max_results: Optional[int] = None, + execution_spec: bigframes.session.execution_spec.ExecutionSpec, ): """ Execute the ArrayValue, storing the result to a temporary session-owned table. """ + if execution_spec.destination_spec is not None: + raise ValueError( + f"TestExecutor does not support destination spec: {execution_spec.destination_spec}" + ) lazy_frame: polars.LazyFrame = self.compiler.compile(array_value.node) + if execution_spec.peek is not None: + lazy_frame = lazy_frame.limit(execution_spec.peek) pa_table = lazy_frame.collect().to_arrow() # Currently, pyarrow types might not quite be exactly the ones in the bigframes schema. # Nullability may be different, and might use large versions of list, string datatypes. 
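For orientation, here is a minimal sketch of how a caller might drive the new spec-driven entry point, based on the ExecutionSpec dataclasses and the abstract Executor.execute signature introduced above; the updated test_session.py calls below follow the same pattern. The `session` and `df` objects, the GCS URI, and the destination table name are placeholders, not part of this change:

import bigframes.session.execution_spec as ex_spec
from google.cloud import bigquery

array_value = df._block.expr  # placeholder: any ArrayValue owned by the session

# Peek at a few rows (replaces the removed executor.peek()).
peeked = session._executor.execute(array_value, ex_spec.ExecutionSpec(peek=10))

# Materialize into a session-owned temp table for caching (the _cache_with_* paths).
session._executor.execute(
    array_value,
    ex_spec.ExecutionSpec(ex_spec.CacheSpec(cluster_cols=()), promise_under_10gb=False),
)

# Export query results to GCS (replaces the removed executor.export_gcs()).
session._executor.execute(
    array_value,
    ex_spec.ExecutionSpec(
        ex_spec.GcsOutputSpec(
            uri="gs://example-bucket/out/*.parquet",  # placeholder bucket
            format="parquet",
            export_options=(),
        )
    ),
)

# Write to an existing BigQuery table (replaces the removed executor.export_gbq()).
session._executor.execute(
    array_value,
    ex_spec.ExecutionSpec(
        ex_spec.TableOutputSpec(
            bigquery.TableReference.from_string("project.dataset.table"),  # placeholder table
            cluster_cols=(),
            if_exists="append",
        )
    ),
)

The BigQuery executor dispatches on the destination spec type: TableOutputSpec routes to _export_gbq, GcsOutputSpec runs the query and then an EXPORT DATA statement via _export_result_gcs, and CacheSpec materializes the result into a clustered temp table and records it for future plan substitution.
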
diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index 6343f0cc53..892f8c8898 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -36,6 +36,7 @@ import bigframes.dataframe import bigframes.dtypes import bigframes.ml.linear_model +import bigframes.session.execution_spec from bigframes.testing import utils all_write_engines = pytest.mark.parametrize( @@ -113,7 +114,10 @@ def test_read_gbq_tokyo( # use_explicit_destination=True, otherwise might use path with no query_job exec_result = session_tokyo._executor.execute( - df._block.expr, use_explicit_destination=True + df._block.expr, + bigframes.session.execution_spec.ExecutionSpec( + bigframes.session.execution_spec.CacheSpec(()), promise_under_10gb=False + ), ) assert exec_result.query_job is not None assert exec_result.query_job.location == tokyo_location @@ -896,7 +900,10 @@ def test_read_pandas_tokyo( expected = scalars_pandas_df_index result = session_tokyo._executor.execute( - df._block.expr, use_explicit_destination=True + df._block.expr, + bigframes.session.execution_spec.ExecutionSpec( + bigframes.session.execution_spec.CacheSpec(()), promise_under_10gb=False + ), ) assert result.query_job is not None assert result.query_job.location == tokyo_location From 6370d3bdd8ac997884b44ca7d8e06773b70947c5 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 2 Sep 2025 15:17:56 -0700 Subject: [PATCH 26/28] chore: refactor test_unary_compiler to apply multiple ops (#2043) --- .../test_endswith/no_pattern.sql | 13 - .../{multiple_patterns.sql => out.sql} | 8 +- .../test_endswith/single_pattern.sql | 13 - .../test_startswith/no_pattern.sql | 13 - .../{multiple_patterns.sql => out.sql} | 8 +- .../test_startswith/single_pattern.sql | 13 - .../test_unary_compiler/test_str_find/out.sql | 10 +- .../test_str_find/out_with_end.sql | 13 - .../test_str_find/out_with_start.sql | 13 - .../test_str_find/out_with_start_and_end.sql | 13 - .../test_unary_compiler/test_str_pad/left.sql | 13 - .../test_str_pad/{both.sql => out.sql} | 8 +- .../test_str_pad/right.sql | 13 - .../test_struct_field/out.sql | 6 +- .../expressions/test_unary_compiler.py | 671 +++++++++++------- 15 files changed, 444 insertions(+), 384 deletions(-) delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql rename tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/{multiple_patterns.sql => out.sql} (62%) delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql rename tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/{multiple_patterns.sql => out.sql} (61%) delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql rename 
tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/{both.sql => out.sql} (63%) delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql deleted file mode 100644 index e9f61ddd7c..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - FALSE AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/out.sql similarity index 62% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/out.sql index f224471e79..e3ac5ec033 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/out.sql @@ -5,9 +5,13 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - ENDS_WITH(`bfcol_0`, 'ab') OR ENDS_WITH(`bfcol_0`, 'cd') AS `bfcol_1` + ENDS_WITH(`bfcol_0`, 'ab') AS `bfcol_1`, + ENDS_WITH(`bfcol_0`, 'ab') OR ENDS_WITH(`bfcol_0`, 'cd') AS `bfcol_2`, + FALSE AS `bfcol_3` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `string_col` + `bfcol_1` AS `single`, + `bfcol_2` AS `double`, + `bfcol_3` AS `empty` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql deleted file mode 100644 index a4e259f0b2..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - ENDS_WITH(`bfcol_0`, 'ab') AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql deleted file mode 100644 index e9f61ddd7c..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - FALSE AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git 
a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/out.sql similarity index 61% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/out.sql index 061b57e208..9679c95f75 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/out.sql @@ -5,9 +5,13 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - STARTS_WITH(`bfcol_0`, 'ab') OR STARTS_WITH(`bfcol_0`, 'cd') AS `bfcol_1` + STARTS_WITH(`bfcol_0`, 'ab') AS `bfcol_1`, + STARTS_WITH(`bfcol_0`, 'ab') OR STARTS_WITH(`bfcol_0`, 'cd') AS `bfcol_2`, + FALSE AS `bfcol_3` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `string_col` + `bfcol_1` AS `single`, + `bfcol_2` AS `double`, + `bfcol_3` AS `empty` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql deleted file mode 100644 index 726ce05b8c..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - STARTS_WITH(`bfcol_0`, 'ab') AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql index dfc100e413..b850262d80 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out.sql @@ -5,9 +5,15 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - INSTR(`bfcol_0`, 'e', 1) - 1 AS `bfcol_1` + INSTR(`bfcol_0`, 'e', 1) - 1 AS `bfcol_1`, + INSTR(`bfcol_0`, 'e', 3) - 1 AS `bfcol_2`, + INSTR(SUBSTRING(`bfcol_0`, 1, 5), 'e') - 1 AS `bfcol_3`, + INSTR(SUBSTRING(`bfcol_0`, 3, 3), 'e') - 1 AS `bfcol_4` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `string_col` + `bfcol_1` AS `none_none`, + `bfcol_2` AS `start_none`, + `bfcol_3` AS `none_end`, + `bfcol_4` AS `start_end` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql deleted file mode 100644 index 78edf662b9..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_end.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - INSTR(SUBSTRING(`bfcol_0`, 1, 5), 'e') - 1 AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM 
`bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql deleted file mode 100644 index d0dfc11a53..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - INSTR(`bfcol_0`, 'e', 3) - 1 AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql deleted file mode 100644 index a91ab32946..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_find/out_with_start_and_end.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - INSTR(SUBSTRING(`bfcol_0`, 3, 3), 'e') - 1 AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql deleted file mode 100644 index ee95900b3e..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/left.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - LPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/out.sql similarity index 63% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/out.sql index 4701b0237a..4226843122 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/both.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/out.sql @@ -5,6 +5,8 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, + LPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1`, + RPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_2`, RPAD( LPAD( `bfcol_0`, @@ -13,9 +15,11 @@ WITH `bfcte_0` AS ( ), GREATEST(LENGTH(`bfcol_0`), 10), '-' - ) AS `bfcol_1` + ) AS `bfcol_3` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `string_col` + `bfcol_1` AS `left`, + `bfcol_2` AS `right`, + `bfcol_3` AS `both` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql deleted file mode 100644 index 17e59c553f..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_str_pad/right.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `string_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - RPAD(`bfcol_0`, GREATEST(LENGTH(`bfcol_0`), 10), '-') AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `string_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_struct_field/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_struct_field/out.sql index b3e8fde0b2..60ae78b755 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_struct_field/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_struct_field/out.sql @@ -5,9 +5,11 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - `bfcol_0`.`name` AS `bfcol_1` + `bfcol_0`.`name` AS `bfcol_1`, + `bfcol_0`.`name` AS `bfcol_2` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `people` + `bfcol_1` AS `string`, + `bfcol_2` AS `int` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 8f3af11842..815bb84a9a 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -12,437 +12,525 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + import pytest from bigframes import operations as ops +from bigframes.core import expression as expr from bigframes.operations._op_converters import convert_index, convert_slice import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") -def _apply_unary_op(obj: bpd.DataFrame, op: ops.UnaryOp, arg: str) -> str: +def _apply_unary_ops( + obj: bpd.DataFrame, + ops_list: typing.Sequence[expr.Expression], + new_names: typing.Sequence[str], +) -> str: array_value = obj._block.expr - op_expr = op.as_expr(arg) - result, col_ids = array_value.compute_values([op_expr]) + result, old_names = array_value.compute_values(ops_list) # Rename columns for deterministic golden SQL results. 
- assert len(col_ids) == 1 - result = result.rename_columns({col_ids[0]: arg}).select_columns([arg]) + assert len(old_names) == len(new_names) + col_ids = {old_name: new_name for old_name, new_name in zip(old_names, new_names)} + result = result.rename_columns(col_ids).select_columns(new_names) sql = result.session._executor.to_sql(result, enable_cache=False) return sql def test_arccosh(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.arccosh_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.arccosh_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_arccos(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.arccos_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.arccos_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_arcsin(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.arcsin_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.arcsin_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_arcsinh(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.arcsinh_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.arcsinh_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_arctan(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.arctan_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.arctan_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_arctanh(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.arctanh_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.arctanh_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_abs(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.abs_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.abs_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_capitalize(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.capitalize_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.capitalize_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_ceil(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.ceil_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.ceil_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_date(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = 
scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.date_op, "timestamp_col") + col_name = "timestamp_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.date_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_day(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.day_op, "timestamp_col") + col_name = "timestamp_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.day_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_dayofweek(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.dayofweek_op, "timestamp_col") + col_name = "timestamp_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.dayofweek_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.dayofyear_op, "timestamp_col") + col_name = "timestamp_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.dayofyear_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_endswith(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab",)), "string_col") - snapshot.assert_match(sql, "single_pattern.sql") - - sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab", "cd")), "string_col") - snapshot.assert_match(sql, "multiple_patterns.sql") - - sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=()), "string_col") - snapshot.assert_match(sql, "no_pattern.sql") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + ops_map = { + "single": ops.EndsWithOp(pat=("ab",)).as_expr(col_name), + "double": ops.EndsWithOp(pat=("ab", "cd")).as_expr(col_name), + "empty": ops.EndsWithOp(pat=()).as_expr(col_name), + } + sql = _apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") def test_exp(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.exp_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.exp_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_expm1(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.expm1_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.expm1_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.FloorDtOp("D"), "timestamp_col") + col_name = "timestamp_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.FloorDtOp("D").as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_floor(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.floor_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.floor_op.as_expr(col_name)], [col_name]) 
snapshot.assert_match(sql, "out.sql") def test_geo_area(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_area_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.geo_area_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_geo_st_astext(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_st_astext_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.geo_st_astext_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_geo_st_boundary(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_st_boundary_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.geo_st_boundary_op.as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_st_buffer(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.GeoStBufferOp(1.0, 8.0, False), "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.GeoStBufferOp(1.0, 8.0, False).as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_st_centroid(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_st_centroid_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.geo_st_centroid_op.as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_st_convexhull(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_st_convexhull_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.geo_st_convexhull_op.as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_st_geogfromtext(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.geo_st_geogfromtext_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.geo_st_geogfromtext_op.as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_st_isclosed(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_st_isclosed_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.geo_st_isclosed_op.as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_st_length(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.GeoStLengthOp(True), "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.GeoStLengthOp(True).as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_geo_x(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = 
scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_x_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.geo_x_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_geo_y(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["geography_col"]] - sql = _apply_unary_op(bf_df, ops.geo_y_op, "geography_col") + col_name = "geography_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.geo_y_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot): - bf_df = repeated_types_df[["string_list_col"]] - sql = _apply_unary_op(bf_df, ops.ArrayToStringOp(delimiter="."), "string_list_col") + col_name = "string_list_col" + bf_df = repeated_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.ArrayToStringOp(delimiter=".").as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_array_index(repeated_types_df: bpd.DataFrame, snapshot): - bf_df = repeated_types_df[["string_list_col"]] - sql = _apply_unary_op(bf_df, convert_index(1), "string_list_col") + col_name = "string_list_col" + bf_df = repeated_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [convert_index(1).as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_array_slice_with_only_start(repeated_types_df: bpd.DataFrame, snapshot): - bf_df = repeated_types_df[["string_list_col"]] - sql = _apply_unary_op(bf_df, convert_slice(slice(1, None)), "string_list_col") + col_name = "string_list_col" + bf_df = repeated_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [convert_slice(slice(1, None)).as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_array_slice_with_start_and_stop(repeated_types_df: bpd.DataFrame, snapshot): - bf_df = repeated_types_df[["string_list_col"]] - sql = _apply_unary_op(bf_df, convert_slice(slice(1, 5)), "string_list_col") + col_name = "string_list_col" + bf_df = repeated_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [convert_slice(slice(1, 5)).as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_cos(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.cos_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.cos_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_cosh(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["float64_col"]] - sql = _apply_unary_op(bf_df, ops.cosh_op, "float64_col") + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.cosh_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_hash(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.hash_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.hash_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_hour(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col"]] - sql = _apply_unary_op(bf_df, ops.hour_op, "timestamp_col") + col_name = "timestamp_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, 
[ops.hour_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_invert(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col"]] - sql = _apply_unary_op(bf_df, ops.invert_op, "int64_col") + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.invert_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_is_in(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col"]] - sql = _apply_unary_op(bf_df, ops.IsInOp(values=(1, 2, 3)), "int64_col") + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops( + bf_df, [ops.IsInOp(values=(1, 2, 3)).as_expr(col_name)], [col_name] + ) snapshot.assert_match(sql, "out.sql") def test_isalnum(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isalnum_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isalnum_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_isalpha(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isalpha_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isalpha_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_isdecimal(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isdecimal_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isdecimal_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_isdigit(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isdigit_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isdigit_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_islower(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.islower_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.islower_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_isnumeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isnumeric_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isnumeric_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_isspace(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isspace_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isspace_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") def test_isupper(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_unary_op(bf_df, ops.isupper_op, "string_col") + col_name = "string_col" + bf_df = scalar_types_df[[col_name]] + sql = _apply_unary_ops(bf_df, [ops.isupper_op.as_expr(col_name)], [col_name]) 
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_len(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.len_op, "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.len_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_ln(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.ln_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.ln_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_log10(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.log10_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.log10_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_log1p(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.log1p_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.log1p_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_lower(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.lower_op, "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.lower_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_map(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(
-        bf_df, ops.MapOp(mappings=(("value1", "mapped1"),)), "string_col"
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df,
+        [ops.MapOp(mappings=(("value1", "mapped1"),)).as_expr(col_name)],
+        [col_name],
     )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_lstrip(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrLstripOp(" "), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.StrLstripOp(" ").as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_minute(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.minute_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.minute_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_month(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.month_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.month_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_neg(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.neg_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.neg_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_normalize(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.normalize_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.normalize_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
@@ -460,257 +548,297 @@ def test_obj_get_access_url(scalar_types_df: bpd.DataFrame, snapshot):
 
 def test_pos(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.pos_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.pos_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_quarter(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.quarter_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.quarter_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.ReplaceStrOp("e", "a"), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.ReplaceStrOp("e", "a").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_regex_replace_str(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.RegexReplaceStrOp(r"e", "a"), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.RegexReplaceStrOp(r"e", "a").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_reverse(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.reverse_op, "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.reverse_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_second(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.second_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.second_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_rstrip(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrRstripOp(" "), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.StrRstripOp(" ").as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_sqrt(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.sqrt_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.sqrt_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_startswith(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab",)), "string_col")
-    snapshot.assert_match(sql, "single_pattern.sql")
-
-    sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab", "cd")), "string_col")
-    snapshot.assert_match(sql, "multiple_patterns.sql")
-    sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=()), "string_col")
-    snapshot.assert_match(sql, "no_pattern.sql")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    ops_map = {
+        "single": ops.StartsWithOp(pat=("ab",)).as_expr(col_name),
+        "double": ops.StartsWithOp(pat=("ab", "cd")).as_expr(col_name),
+        "empty": ops.StartsWithOp(pat=()).as_expr(col_name),
+    }
+    sql = _apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
+    snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_get(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrGetOp(1), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.StrGetOp(1).as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_pad(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(
-        bf_df, ops.StrPadOp(length=10, fillchar="-", side="left"), "string_col"
-    )
-    snapshot.assert_match(sql, "left.sql")
-
-    sql = _apply_unary_op(
-        bf_df, ops.StrPadOp(length=10, fillchar="-", side="right"), "string_col"
-    )
-    snapshot.assert_match(sql, "right.sql")
-
-    sql = _apply_unary_op(
-        bf_df, ops.StrPadOp(length=10, fillchar="-", side="both"), "string_col"
-    )
-    snapshot.assert_match(sql, "both.sql")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    ops_map = {
+        "left": ops.StrPadOp(length=10, fillchar="-", side="left").as_expr(col_name),
+        "right": ops.StrPadOp(length=10, fillchar="-", side="right").as_expr(col_name),
+        "both": ops.StrPadOp(length=10, fillchar="-", side="both").as_expr(col_name),
+    }
+    sql = _apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
+    snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_slice(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrSliceOp(1, 3), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.StrSliceOp(1, 3).as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_strftime(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrftimeOp("%Y-%m-%d"), "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.StrftimeOp("%Y-%m-%d").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
-    bf_df = nested_structs_types_df[["people"]]
+    col_name = "people"
+    bf_df = nested_structs_types_df[[col_name]]
 
-    # When a name string is provided.
-    sql = _apply_unary_op(bf_df, ops.StructFieldOp("name"), "people")
-    snapshot.assert_match(sql, "out.sql")
+    ops_map = {
+        # When a name string is provided.
+        "string": ops.StructFieldOp("name").as_expr(col_name),
+        # When an index integer is provided.
+        "int": ops.StructFieldOp(0).as_expr(col_name),
+    }
+    sql = _apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
 
-    # When an index integer is provided.
-    sql = _apply_unary_op(bf_df, ops.StructFieldOp(0), "people")
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_contains(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrContainsOp("e"), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.StrContainsOp("e").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_contains_regex(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrContainsRegexOp("e"), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.StrContainsRegexOp("e").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_extract(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrExtractOp(r"([a-z]*)", 1), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.StrExtractOp(r"([a-z]*)", 1).as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_repeat(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrRepeatOp(2), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.StrRepeatOp(2).as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_str_find(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=None, end=None), "string_col")
-    snapshot.assert_match(sql, "out.sql")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    ops_map = {
+        "none_none": ops.StrFindOp("e", start=None, end=None).as_expr(col_name),
+        "start_none": ops.StrFindOp("e", start=2, end=None).as_expr(col_name),
+        "none_end": ops.StrFindOp("e", start=None, end=5).as_expr(col_name),
+        "start_end": ops.StrFindOp("e", start=2, end=5).as_expr(col_name),
+    }
+    sql = _apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
 
-    sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=2, end=None), "string_col")
-    snapshot.assert_match(sql, "out_with_start.sql")
-
-    sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=None, end=5), "string_col")
-    snapshot.assert_match(sql, "out_with_end.sql")
-
-    sql = _apply_unary_op(bf_df, ops.StrFindOp("e", start=2, end=5), "string_col")
-    snapshot.assert_match(sql, "out_with_start_and_end.sql")
+    snapshot.assert_match(sql, "out.sql")
 
 
 def test_strip(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StrStripOp(" "), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.StrStripOp(" ").as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_iso_day(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.iso_day_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.iso_day_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_iso_week(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.iso_week_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.iso_week_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_iso_year(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.iso_year_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.iso_year_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_isnull(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.isnull_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.isnull_op.as_expr(col_name)], [col_name])
    snapshot.assert_match(sql, "out.sql")
 
 
 def test_notnull(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.notnull_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.notnull_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_sin(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.sin_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.sin_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_sinh(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.sinh_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.sinh_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_string_split(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.StringSplitOp(pat=","), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.StringSplitOp(pat=",").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_tan(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.tan_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.tan_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_tanh(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["float64_col"]]
-    sql = _apply_unary_op(bf_df, ops.tanh_op, "float64_col")
+    col_name = "float64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.tanh_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_time(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.time_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.time_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_to_datetime(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["int64_col"]]
-    sql = _apply_unary_op(bf_df, ops.ToDatetimeOp(), "int64_col")
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.ToDatetimeOp().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_to_timestamp(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["int64_col"]]
-    sql = _apply_unary_op(bf_df, ops.ToTimestampOp(), "int64_col")
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.ToTimestampOp().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
@@ -725,104 +853,133 @@ def test_to_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
 
 def test_unix_micros(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.UnixMicros(), "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.UnixMicros().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_unix_millis(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.UnixMillis(), "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.UnixMillis().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_unix_seconds(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.UnixSeconds(), "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.UnixSeconds().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_timedelta_floor(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["int64_col"]]
-    sql = _apply_unary_op(bf_df, ops.timedelta_floor_op, "int64_col")
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.timedelta_floor_op.as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_json_extract(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.JSONExtract(json_path="$"), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.JSONExtract(json_path="$").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_json_extract_array(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.JSONExtractArray(json_path="$"), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.JSONExtractArray(json_path="$").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_json_extract_string_array(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.JSONExtractStringArray(json_path="$"), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.JSONExtractStringArray(json_path="$").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_json_query(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.JSONQuery(json_path="$"), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.JSONQuery(json_path="$").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_json_query_array(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.JSONQueryArray(json_path="$"), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.JSONQueryArray(json_path="$").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_json_value(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.JSONValue(json_path="$"), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(
+        bf_df, [ops.JSONValue(json_path="$").as_expr(col_name)], [col_name]
+    )
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_parse_json(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.ParseJSON(), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.ParseJSON().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_to_json_string(json_types_df: bpd.DataFrame, snapshot):
-    bf_df = json_types_df[["json_col"]]
-    sql = _apply_unary_op(bf_df, ops.ToJSONString(), "json_col")
+    col_name = "json_col"
+    bf_df = json_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.ToJSONString().as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_upper(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.upper_op, "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.upper_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_year(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["timestamp_col"]]
-    sql = _apply_unary_op(bf_df, ops.year_op, "timestamp_col")
+    col_name = "timestamp_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.year_op.as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")
 
 
 def test_zfill(scalar_types_df: bpd.DataFrame, snapshot):
-    bf_df = scalar_types_df[["string_col"]]
-    sql = _apply_unary_op(bf_df, ops.ZfillOp(width=10), "string_col")
+    col_name = "string_col"
+    bf_df = scalar_types_df[[col_name]]
+    sql = _apply_unary_ops(bf_df, [ops.ZfillOp(width=10).as_expr(col_name)], [col_name])
     snapshot.assert_match(sql, "out.sql")

From 2d2460650f8e0241d2b860aa915d51122db2509d Mon Sep 17 00:00:00 2001
From: Garrett Wu <6505921+GarrettWu@users.noreply.github.com>
Date: Wed, 3 Sep 2025 10:46:13 -0700
Subject: [PATCH 27/28] test: fix blob snippets tests gcs folder wipeout (#2044)

---
 samples/snippets/conftest.py        | 11 +++++++++++
 samples/snippets/multimodal_test.py |  4 ++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/samples/snippets/conftest.py b/samples/snippets/conftest.py
index 81595967ec..e19cfbceb4 100644
--- a/samples/snippets/conftest.py
+++ b/samples/snippets/conftest.py
@@ -63,6 +63,17 @@ def gcs_bucket(storage_client: storage.Client) -> Generator[str, None, None]:
         blob.delete()
 
 
+@pytest.fixture(scope="session")
+def gcs_bucket_snippets(storage_client: storage.Client) -> Generator[str, None, None]:
+    bucket_name = "bigframes_blob_test_snippet_with_data_wipeout"
+
+    yield bucket_name
+
+    bucket = storage_client.get_bucket(bucket_name)
+    for blob in bucket.list_blobs():
+        blob.delete()
+
+
 @pytest.fixture(autouse=True)
 def reset_session() -> None:
     """An autouse fixture ensuring each sample runs in a fresh session.
diff --git a/samples/snippets/multimodal_test.py b/samples/snippets/multimodal_test.py
index 1ea6a3f0a6..033fead33e 100644
--- a/samples/snippets/multimodal_test.py
+++ b/samples/snippets/multimodal_test.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 
-def test_multimodal_dataframe(gcs_bucket: str) -> None:
+def test_multimodal_dataframe(gcs_bucket_snippets: str) -> None:
     # destination folder must be in a GCS bucket that the BQ connection service account (default or user provided) has write access to.
-    dst_bucket = f"gs://{gcs_bucket}"
+    dst_bucket = f"gs://{gcs_bucket_snippets}"
     # [START bigquery_dataframes_multimodal_dataframe_create]
     import bigframes

From 88115fadbf366bd7bcfa3b7ddd2cd4e6d4ad15e2 Mon Sep 17 00:00:00 2001
From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com>
Date: Wed, 3 Sep 2025 12:07:09 -0700
Subject: [PATCH 28/28] chore(main): release 2.18.0 (#2023)

Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>
---
 CHANGELOG.md                              | 32 +++++++++++++++++++++++
 bigframes/version.py                      |  4 +--
 third_party/bigframes_vendored/version.py |  4 +--
 3 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fc4362cc87..433956da3d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,38 @@
 
 [1]: https://pypi.org/project/bigframes/#history
 
+## [2.18.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v2.17.0...v2.18.0) (2025-09-03)
+
+
+### ⚠ BREAKING CHANGES
+
+* add `allow_large_results` option to `read_gbq_query`, aligning with `bpd.options.compute.allow_large_results` option ([#1935](https://github.com/googleapis/python-bigquery-dataframes/issues/1935))
+
+### Features
+
+* Add `allow_large_results` option to `read_gbq_query`, aligning with `bpd.options.compute.allow_large_results` option ([#1935](https://github.com/googleapis/python-bigquery-dataframes/issues/1935)) ([a7963fe](https://github.com/googleapis/python-bigquery-dataframes/commit/a7963fe57a0e141debf726f0bc7b0e953ebe9634))
+* Add parameter shuffle for ml.model_selection.train_test_split ([#2030](https://github.com/googleapis/python-bigquery-dataframes/issues/2030)) ([2c72c56](https://github.com/googleapis/python-bigquery-dataframes/commit/2c72c56fb5893eb01d5aec6273d11945c9c532c5))
+* Can pivot unordered, unindexed dataframe ([#2040](https://github.com/googleapis/python-bigquery-dataframes/issues/2040)) ([1a0f710](https://github.com/googleapis/python-bigquery-dataframes/commit/1a0f710ac11418fd71ab3373f3f6002fa581b180))
+* Local date accessor execution support ([#2034](https://github.com/googleapis/python-bigquery-dataframes/issues/2034)) ([7ac6fe1](https://github.com/googleapis/python-bigquery-dataframes/commit/7ac6fe16f7f2c09d2efac6ab813ec841c21baef8))
+* Support args in dataframe apply method ([#2026](https://github.com/googleapis/python-bigquery-dataframes/issues/2026)) ([164c481](https://github.com/googleapis/python-bigquery-dataframes/commit/164c4818bc4ff2990dca16b9f22a798f47e0a60b))
+* Support args in series apply method ([#2013](https://github.com/googleapis/python-bigquery-dataframes/issues/2013)) ([d9d725c](https://github.com/googleapis/python-bigquery-dataframes/commit/d9d725cfbc3dca9e66b460cae4084e25162f2acf))
+* Support callable for dataframe mask method ([#2020](https://github.com/googleapis/python-bigquery-dataframes/issues/2020)) ([9d4504b](https://github.com/googleapis/python-bigquery-dataframes/commit/9d4504be310d38b63515d67c0f60d2e48e68c7b5))
+* Support multi-column assignment for DataFrame ([#2028](https://github.com/googleapis/python-bigquery-dataframes/issues/2028)) ([ba0d23b](https://github.com/googleapis/python-bigquery-dataframes/commit/ba0d23b59c44ba5a46ace8182ad0e0cfc703b3ab))
+* Support string matching in local executor ([#2032](https://github.com/googleapis/python-bigquery-dataframes/issues/2032)) ([c0b54f0](https://github.com/googleapis/python-bigquery-dataframes/commit/c0b54f03849ee3115413670e690e68f3ef10f2ec))
+
+
+### Bug Fixes
+
+* Fix scalar op lowering tree walk ([#2029](https://github.com/googleapis/python-bigquery-dataframes/issues/2029)) ([935af10](https://github.com/googleapis/python-bigquery-dataframes/commit/935af107ef98837fb2b81d72185d0b6a9e09fbcf))
+* Read_csv fails when check file size for wildcard gcs files ([#2019](https://github.com/googleapis/python-bigquery-dataframes/issues/2019)) ([b0d620b](https://github.com/googleapis/python-bigquery-dataframes/commit/b0d620bbe8227189bbdc2ba5a913b03c70575296))
+* Resolve the validation issue for other arg in dataframe where method ([#2042](https://github.com/googleapis/python-bigquery-dataframes/issues/2042)) ([8689199](https://github.com/googleapis/python-bigquery-dataframes/commit/8689199aa82212ed300fff592097093812e0290e))
+
+
+### Performance Improvements
+
+* Improve axis=1 aggregation performance ([#2036](https://github.com/googleapis/python-bigquery-dataframes/issues/2036)) ([fbb2094](https://github.com/googleapis/python-bigquery-dataframes/commit/fbb209468297a8057d9d49c40e425c3bfdeb92bd))
+* Improve iter_nodes_topo performance using Kahn's algorithm ([#2038](https://github.com/googleapis/python-bigquery-dataframes/issues/2038)) ([3961637](https://github.com/googleapis/python-bigquery-dataframes/commit/39616374bba424996ebeb9a12096bfaf22660b44))
+
 ## [2.17.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v2.16.0...v2.17.0) (2025-08-22)
 
diff --git a/bigframes/version.py b/bigframes/version.py
index b9aa5d1855..78b6498d2d 100644
--- a/bigframes/version.py
+++ b/bigframes/version.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2.17.0"
+__version__ = "2.18.0"
 
 # {x-release-please-start-date}
-__release_date__ = "2025-08-22"
+__release_date__ = "2025-09-03"
 # {x-release-please-end}
diff --git a/third_party/bigframes_vendored/version.py b/third_party/bigframes_vendored/version.py
index b9aa5d1855..78b6498d2d 100644
--- a/third_party/bigframes_vendored/version.py
+++ b/third_party/bigframes_vendored/version.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2.17.0"
+__version__ = "2.18.0"
 
 # {x-release-please-start-date}
-__release_date__ = "2025-08-22"
+__release_date__ = "2025-09-03"
 # {x-release-please-end}