8000 refactor: add _join_condition for all types by chelsea-lin · Pull Request #1880 · googleapis/python-bigquery-dataframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bigframes/core/compile/sqlglot/scalar_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def compile_scalar_expression(

@compile_scalar_expression.register
def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
    """Compile a column dereference into a quoted SQLGlot column reference.

    Args:
        expr: The dereference operation; ``expr.id.sql`` supplies the
            column name.

    Returns:
        A ``sge.Column`` wrapping the quoted identifier.
    """
    # A dereference reads an existing column, so emit a column reference
    # (``Column``) rather than a column-definition node (``ColumnDef``).
    return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=True))


@compile_scalar_expression.register
Expand Down
87 changes: 86 additions & 1 deletion bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,4 +491,89 @@ def _join_condition(
right: typed_expr.TypedExpr,
joins_nulls: bool,
) -> typing.Union[sge.EQ, sge.And]:
return sge.EQ(this=left.expr, expression=right.expr)
"""Generates a join condition to match pandas's null-handling logic.

Pandas matches null join keys against each other, leading to a
cross-join-like behavior among the null keys. In contrast, BigQuery SQL
evaluates `NULL = NULL` as NULL rather than TRUE, so rows with null keys
never match under a plain equality condition.

This function generates the appropriate SQL condition to replicate the
desired pandas behavior in BigQuery.

Args:
left: The left-side join key.
right: The right-side join key.
joins_nulls: If True, generates complex logic to handle nulls/NaNs.
Otherwise, uses a simple equality check where appropriate.
"""
is_floating_types = (
left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
)
if not is_floating_types and not joins_nulls:
return sge.EQ(this=left.expr, expression=right.expr)

is_numeric_types = dtypes.is_numeric(
left.dtype, include_bool=False
) and dtypes.is_numeric(right.dtype, include_bool=False)
if is_numeric_types:
return _join_condition_for_numeric(left, right)
else:
return _join_condition_for_others(left, right)


def _join_condition_for_others(
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
) -> sge.And:
    """Build a null-matching join condition for non-numeric join keys.

    Both keys are cast to STRING and compared twice, each time with NULL
    replaced by a different sentinel ("0", then "1"). A NULL key therefore
    satisfies both comparisons only against another NULL key, while a
    non-NULL key must equal the other key in both comparisons.

    Args:
        left: The left-side join key.
        right: The right-side join key.

    Returns:
        The conjunction of the two sentinel-filled equality checks.
    """
    lhs_text = _cast(left.expr, "STRING")
    rhs_text = _cast(right.expr, "STRING")

    def _equal_with_null_as(sentinel: str) -> sge.EQ:
        # Substitute the same sentinel on both sides before comparing.
        lhs = sge.func(
            "COALESCE", lhs_text, _literal(sentinel, dtypes.STRING_DTYPE)
        )
        rhs = sge.func(
            "COALESCE", rhs_text, _literal(sentinel, dtypes.STRING_DTYPE)
        )
        return sge.EQ(this=lhs, expression=rhs)

    return sge.And(
        this=_equal_with_null_as("0"),
        expression=_equal_with_null_as("1"),
    )


def _join_condition_for_numeric(
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
) -> sge.And:
    """Build a null-matching join condition for numeric join keys.

    The keys are compared twice, each time with NULL replaced by a different
    sentinel (0, then 1), so a NULL key only satisfies both comparisons
    against another NULL key.

    When both keys are FLOAT, NaN is additionally replaced by its own pair of
    sentinels (2, then 3). A NaN key then only matches another NaN key, and
    never a NULL key or an ordinary value.

    Args:
        left: The left-side join key.
        right: The right-side join key.

    Returns:
        The conjunction of the two sentinel-filled equality checks.
    """
    both_float = (
        left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
    )

    def _null_filled(operand: typed_expr.TypedExpr, sentinel: int) -> sge.Expression:
        return sge.func(
            "COALESCE", operand.expr, _literal(sentinel, operand.dtype)
        )

    def _nan_filled(
        operand: typed_expr.TypedExpr,
        sentinel: int,
        otherwise: sge.Expression,
    ) -> sge.Expression:
        return sge.If(
            this=sge.IsNan(this=operand.expr),
            true=_literal(sentinel, operand.dtype),
            false=otherwise,
        )

    checks = []
    # Each pair is (sentinel used for NULL, sentinel used for NaN).
    for null_sentinel, nan_sentinel in ((0, 2), (1, 3)):
        lhs = _null_filled(left, null_sentinel)
        rhs = _null_filled(right, null_sentinel)
        if both_float:
            lhs = _nan_filled(left, nan_sentinel, lhs)
            rhs = _nan_filled(right, nan_sentinel, rhs)
        checks.append(sge.EQ(this=lhs, expression=rhs))

    first_check, second_check = checks
    return sge.And(this=first_check, expression=second_check)
5 changes: 3 additions & 2 deletions bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,9 @@ def is_json_encoding_type(type_: ExpressionType) -> bool:
return type_ != GEO_DTYPE


def is_numeric(type_: ExpressionType, include_bool: bool = True) -> bool:
    """Return whether ``type_`` is one of the permissive numeric types.

    Args:
        type_: The type to check.
        include_bool: When False, ``BOOL_DTYPE`` is never reported as
            numeric. The default of True keeps the single-argument
            behavior of a plain membership check.

    Returns:
        True if ``type_`` is in ``NUMERIC_BIGFRAMES_TYPES_PERMISSIVE`` and,
        when ``include_bool`` is False, is also not ``BOOL_DTYPE``.
    """
    if type_ not in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE:
        return False
    if include_bool:
        return True
    # NOTE(review): this exclusion only matters if BOOL_DTYPE is a member of
    # the permissive set, which is presumably the case; confirm.
    return type_ != BOOL_DTYPE


def is_iterable(type_: ExpressionType) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ WITH `bfcte_1` AS (
*
FROM `bfcte_2`
LEFT JOIN `bfcte_3`
ON `bfcol_2` = `bfcol_6`
ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0)
AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1)
)
SELECT
`bfcol_3` AS `int64_col`,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
WITH `bfcte_1` AS (
SELECT
`bool_col` AS `bfcol_0`,
`rowindex` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_1` AS `bfcol_2`,
`bfcol_0` AS `bfcol_3`
FROM `bfcte_1`
), `bfcte_0` AS (
SELECT
`bool_col` AS `bfcol_4`,
`rowindex` AS `bfcol_5`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_3` AS (
SELECT
`bfcol_5` AS `bfcol_6`,
`bfcol_4` AS `bfcol_7`
FROM `bfcte_0`
), `bfcte_4` AS (
SELECT
*
FROM `bfcte_2`
INNER JOIN `bfcte_3`
ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0')
AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1')
)
SELECT
`bfcol_2` AS `rowindex_x`,
`bfcol_3` AS `bool_col`,
`bfcol_6` AS `rowindex_y`
FROM `bfcte_4`
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
WITH `bfcte_1` AS (
SELECT
`float64_col` AS `bfcol_0`,
`rowindex` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_1` AS `bfcol_2`,
`bfcol_0` AS `bfcol_3`
FROM `bfcte_1`
), `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_4`,
`rowindex` AS `bfcol_5`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_3` AS (
SELECT
`bfcol_5` AS `bfcol_6`,
`bfcol_4` AS `bfcol_7`
FROM `bfcte_0`
), `bfcte_4` AS (
SELECT
*
FROM `bfcte_2`
INNER JOIN `bfcte_3`
ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0))
AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1))
)
SELECT
`bfcol_2` AS `rowindex_x`,
`bfcol_3` AS `float64_col`,
`bfcol_6` AS `rowindex_y`
FROM `bfcte_4`
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
WITH `bfcte_1` AS (
SELECT
`int64_col` AS `bfcol_0`,
`rowindex` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_1` AS `bfcol_2`,
`bfcol_0` AS `bfcol_3`
FROM `bfcte_1`
), `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_4`,
`rowindex` AS `bfcol_5`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_3` AS (
SELECT
`bfcol_5` AS `bfcol_6`,
`bfcol_4` AS `bfcol_7`
FROM `bfcte_0`
), `bfcte_4` AS (
SELECT
*
FROM `bfcte_2`
INNER JOIN `bfcte_3`
ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0)
AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1)
)
SELECT
`bfcol_2` AS `rowindex_x`,
`bfcol_3` AS `int64_col`,
`bfcol_6` AS `rowindex_y`
FROM `bfcte_4`
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
WITH `bfcte_1` AS (
SELECT
`numeric_col` AS `bfcol_0`,
`rowindex` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_1` AS `bfcol_2`,
`bfcol_0` AS `bfcol_3`
FROM `bfcte_1`
), `bfcte_0` AS (
SELECT
`numeric_col` AS `bfcol_4`,
`rowindex` AS `bfcol_5`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_3` AS (
SELECT
`bfcol_5` AS `bfcol_6`,
`bfcol_4` AS `bfcol_7`
FROM `bfcte_0`
), `bfcte_4` AS (
SELECT
*
FROM `bfcte_2`
INNER JOIN `bfcte_3`
ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC))
AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC))
)
SELECT
`bfcol_2` AS `rowindex_x`,
`bfcol_3` AS `numeric_col`,
`bfcol_6` AS `rowindex_y`
FROM `bfcte_4`
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
WITH `bfcte_1` AS (
SELECT
`rowindex` AS `bfcol_0`,
`string_col` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_0` AS (
SELECT
`rowindex` AS `bfcol_2`,
`string_col` AS `bfcol_3`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_2` AS `bfcol_4`,
`bfcol_3` AS `bfcol_5`
FROM `bfcte_0`
), `bfcte_3` AS (
SELECT
*
FROM `bfcte_1`
INNER JOIN `bfcte_2`
ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0')
AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1')
)
SELECT
`bfcol_0` AS `rowindex_x`,
`bfcol_1` AS `string_col`,
`bfcol_4` AS `rowindex_y`
FROM `bfcte_3`
A3E2
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
WITH `bfcte_1` AS (
SELECT
`rowindex` AS `bfcol_0`,
`time_col` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_0` AS (
SELECT
`rowindex` AS `bfcol_2`,
`time_col` AS `bfcol_3`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_2` AS `bfcol_4`,
`bfcol_3` AS `bfcol_5`
FROM `bfcte_0`
), `bfcte_3` AS (
SELECT
*
FROM `bfcte_1`
INNER JOIN `bfcte_2`
ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0')
AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1')
)
SELECT
`bfcol_0` AS `rowindex_x`,
`bfcol_1` AS `time_col`,
`bfcol_4` AS `rowindex_y`
FROM `bfcte_3`
10 changes: 10 additions & 0 deletions tests/unit/core/compile/sqlglot/test_compile_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ def test_compile_join_w_how(scalar_types_df: bpd.DataFrame):
join_sql = left.merge(right, how="cross").sql
assert "CROSS JOIN" in join_sql
assert "ON" not in join_sql


@pytest.mark.parametrize(
    ("on"),
    ["bool_col", "int64_col", "float64_col", "string_col", "time_col", "numeric_col"],
)
def test_compile_join_w_on(scalar_types_df: bpd.DataFrame, on: str, snapshot):
    """Snapshot the SQL compiled for a self-merge keyed on the ``on`` column."""
    # Keep only the row index and the join key so the snapshot stays small.
    selected = scalar_types_df[["rowindex", on]]
    compiled_sql = selected.merge(selected, left_on=on, right_on=on).sql
    snapshot.assert_match(compiled_sql, "out.sql")
0