From 97c3fe1b0b0735f9bcd1b57e153728ece2fb3472 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 1 Jul 2025 19:23:24 +0000 Subject: [PATCH] refactor: add _join_condition for all types --- .../core/compile/sqlglot/scalar_compiler.py | 2 +- bigframes/core/compile/sqlglot/sqlglot_ir.py | 87 ++++++++++++++++++- bigframes/dtypes.py | 5 +- .../test_compile_join/out.sql | 3 +- .../test_compile_join_w_on/bool_col/out.sql | 33 +++++++ .../float64_col/out.sql | 33 +++++++ .../test_compile_join_w_on/int64_col/out.sql | 33 +++++++ .../numeric_col/out.sql | 33 +++++++ .../test_compile_join_w_on/string_col/out.sql | 28 ++++++ .../test_compile_join_w_on/time_col/out.sql | 28 ++++++ .../core/compile/sqlglot/test_compile_join.py | 10 +++ 11 files changed, 290 insertions(+), 5 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py index 0db507b0fa..683dd38c9a 100644 --- a/bigframes/core/compile/sqlglot/scalar_compiler.py +++ b/bigframes/core/compile/sqlglot/scalar_compiler.py @@ -39,7 +39,7 @@ def compile_scalar_expression( @compile_scalar_expression.register def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression: - return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True)) + return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=True)) @compile_scalar_expression.register diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 3b4d7ed0ce..d5902fa6fc 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -491,4 +491,89 @@ def _join_condition( right: typed_expr.TypedExpr, joins_nulls: bool, ) -> typing.Union[sge.EQ, sge.And]: - return sge.EQ(this=left.expr, expression=right.expr) + """Generates a join condition to match pandas's null-handling logic. + + Pandas treats null values as distinct from each other, leading to a + cross-join-like behavior for null keys. In contrast, BigQuery SQL treats + null values as equal, leading to a inner-join-like behavior. + + This function generates the appropriate SQL condition to replicate the + desired pandas behavior in BigQuery. + + Args: + left: The left-side join key. + right: The right-side join key. + joins_nulls: If True, generates complex logic to handle nulls/NaNs. + Otherwise, uses a simple equality check where appropriate. + """ + is_floating_types = ( + left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE + ) + if not is_floating_types and not joins_nulls: + return sge.EQ(this=left.expr, expression=right.expr) + + is_numeric_types = dtypes.is_numeric( + left.dtype, include_bool=False + ) and dtypes.is_numeric(right.dtype, include_bool=False) + if is_numeric_types: + return _join_condition_for_numeric(left, right) + else: + return _join_condition_for_others(left, right) + + +def _join_condition_for_others( + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, +) -> sge.And: + """Generates a join condition for non-numeric types to match pandas's + null-handling logic. + """ + left_str = _cast(left.expr, "STRING") + right_str = _cast(right.expr, "STRING") + left_0 = sge.func("COALESCE", left_str, _literal("0", dtypes.STRING_DTYPE)) + left_1 = sge.func("COALESCE", left_str, _literal("1", dtypes.STRING_DTYPE)) + right_0 = sge.func("COALESCE", right_str, _literal("0", dtypes.STRING_DTYPE)) + right_1 = sge.func("COALESCE", right_str, _literal("1", dtypes.STRING_DTYPE)) + return sge.And( + this=sge.EQ(this=left_0, expression=right_0), + expression=sge.EQ(this=left_1, expression=right_1), + ) + + +def _join_condition_for_numeric( + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, +) -> sge.And: + """Generates a join condition for non-numeric types to match pandas's + null-handling logic. Specifically for FLOAT types, Pandas treats NaN aren't + equal so need to coalesce as well with different constants. + """ + is_floating_types = ( + left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE + ) + left_0 = sge.func("COALESCE", left.expr, _literal(0, left.dtype)) + left_1 = sge.func("COALESCE", left.expr, _literal(1, left.dtype)) + right_0 = sge.func("COALESCE", right.expr, _literal(0, right.dtype)) + right_1 = sge.func("COALESCE", right.expr, _literal(1, right.dtype)) + if not is_floating_types: + return sge.And( + this=sge.EQ(this=left_0, expression=right_0), + expression=sge.EQ(this=left_1, expression=right_1), + ) + + left_2 = sge.If( + this=sge.IsNan(this=left.expr), true=_literal(2, left.dtype), false=left_0 + ) + left_3 = sge.If( + this=sge.IsNan(this=left.expr), true=_literal(3, left.dtype), false=left_1 + ) + right_2 = sge.If( + this=sge.IsNan(this=right.expr), true=_literal(2, right.dtype), false=right_0 + ) + right_3 = sge.If( + this=sge.IsNan(this=right.expr), true=_literal(3, right.dtype), false=right_1 + ) + return sge.And( + this=sge.EQ(this=left_2, expression=right_2), + expression=sge.EQ(this=left_3, expression=right_3), + ) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index b0a31595e5..e41be5efc6 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -340,8 +340,9 @@ def is_json_encoding_type(type_: ExpressionType) -> bool: return type_ != GEO_DTYPE -def is_numeric(type_: ExpressionType) -> bool: - return type_ in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE +def is_numeric(type_: ExpressionType, include_bool: bool = True) -> bool: + is_numeric = type_ in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE + return is_numeric if include_bool else is_numeric and type_ != BOOL_DTYPE def is_iterable(type_: ExpressionType) -> bool: diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql index aefaa28dfb..85eab4487a 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql @@ -23,7 +23,8 @@ WITH `bfcte_1` AS ( * FROM `bfcte_2` LEFT JOIN `bfcte_3` - ON `bfcol_2` = `bfcol_6` + ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0) + AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1) ) SELECT `bfcol_3` AS `int64_col`, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql new file mode 100644 index 0000000000..a073e35c69 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql @@ -0,0 +1,33 @@ +WITH `bfcte_1` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_1` +), `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_4`, + `rowindex` AS `bfcol_5` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_3` AS ( + SELECT + `bfcol_5` AS `bfcol_6`, + `bfcol_4` AS `bfcol_7` + FROM `bfcte_0` +), `bfcte_4` AS ( + SELECT + * + FROM `bfcte_2` + INNER JOIN `bfcte_3` + ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0') + AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1') +) +SELECT + `bfcol_2` AS `rowindex_x`, + `bfcol_3` AS `bool_col`, + `bfcol_6` AS `rowindex_y` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql new file mode 100644 index 0000000000..1d04343f31 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql @@ -0,0 +1,33 @@ +WITH `bfcte_1` AS ( + SELECT + `float64_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_1` +), `bfcte_0` AS ( + SELECT + `float64_col` AS `bfcol_4`, + `rowindex` AS `bfcol_5` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_3` AS ( + SELECT + `bfcol_5` AS `bfcol_6`, + `bfcol_4` AS `bfcol_7` + FROM `bfcte_0` +), `bfcte_4` AS ( + SELECT + * + FROM `bfcte_2` + INNER JOIN `bfcte_3` + ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0)) + AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1)) +) +SELECT + `bfcol_2` AS `rowindex_x`, + `bfcol_3` AS `float64_col`, + `bfcol_6` AS `rowindex_y` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql new file mode 100644 index 0000000000..80ec5d19d1 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql @@ -0,0 +1,33 @@ +WITH `bfcte_1` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_1` +), `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_4`, + `rowindex` AS `bfcol_5` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_3` AS ( + SELECT + `bfcol_5` AS `bfcol_6`, + `bfcol_4` AS `bfcol_7` + FROM `bfcte_0` +), `bfcte_4` AS ( + SELECT + * + FROM `bfcte_2` + INNER JOIN `bfcte_3` + ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0) + AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1) +) +SELECT + `bfcol_2` AS `rowindex_x`, + `bfcol_3` AS `int64_col`, + `bfcol_6` AS `rowindex_y` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql new file mode 100644 index 0000000000..22ce6f5b29 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql @@ -0,0 +1,33 @@ +WITH `bfcte_1` AS ( + SELECT + `numeric_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_1` +), `bfcte_0` AS ( + SELECT + `numeric_col` AS `bfcol_4`, + `rowindex` AS `bfcol_5` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_3` AS ( + SELECT + `bfcol_5` AS `bfcol_6`, + `bfcol_4` AS `bfcol_7` + FROM `bfcte_0` +), `bfcte_4` AS ( + SELECT + * + FROM `bfcte_2` + INNER JOIN `bfcte_3` + ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC)) + AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC)) +) +SELECT + `bfcol_2` AS `rowindex_x`, + `bfcol_3` AS `numeric_col`, + `bfcol_6` AS `rowindex_y` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql new file mode 100644 index 0000000000..5e8d072d46 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql @@ -0,0 +1,28 @@ +WITH `bfcte_1` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `string_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_2`, + `string_col` AS `bfcol_3` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_2` AS `bfcol_4`, + `bfcol_3` AS `bfcol_5` + FROM `bfcte_0` +), `bfcte_3` AS ( + SELECT + * + FROM `bfcte_1` + INNER JOIN `bfcte_2` + ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') + AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') +) +SELECT + `bfcol_0` AS `rowindex_x`, + `bfcol_1` AS `string_col`, + `bfcol_4` AS `rowindex_y` +FROM `bfcte_3` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql new file mode 100644 index 0000000000..b0df619f25 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql @@ -0,0 +1,28 @@ +WITH `bfcte_1` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `time_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_2`, + `time_col` AS `bfcol_3` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_2` AS `bfcol_4`, + `bfcol_3` AS `bfcol_5` + FROM `bfcte_0` +), `bfcte_3` AS ( + SELECT + * + FROM `bfcte_1` + INNER JOIN `bfcte_2` + ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0') + AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1') +) +SELECT + `bfcol_0` AS `rowindex_x`, + `bfcol_1` AS `time_col`, + `bfcol_4` AS `rowindex_y` +FROM `bfcte_3` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_join.py b/tests/unit/core/compile/sqlglot/test_compile_join.py index a530ed4fc3..ac016eec02 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_join.py +++ b/tests/unit/core/compile/sqlglot/test_compile_join.py @@ -49,3 +49,13 @@ def test_compile_join_w_how(scalar_types_df: bpd.DataFrame): join_sql = left.merge(right, how="cross").sql assert "CROSS JOIN" in join_sql assert "ON" not in join_sql + + +@pytest.mark.parametrize( + ("on"), + ["bool_col", "int64_col", "float64_col", "string_col", "time_col", "numeric_col"], +) +def test_compile_join_w_on(scalar_types_df: bpd.DataFrame, on: str, snapshot): + df = scalar_types_df[["rowindex", on]] + merge = df.merge(df, left_on=on, right_on=on) + snapshot.assert_match(merge.sql, "out.sql")