chore: migrate 10 scalar operators to SQLGlot by chelsea-lin · Pull Request #1922 · googleapis/python-bigquery-dataframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,52 @@
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

# FLOAT64 NaN, expressed as a cast from the string literal 'NaN'.
_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64")
# FLOAT64 positive infinity, expressed as a cast from the string literal 'Infinity'.
_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64")

# Approximate largest argument for which EXP still returns a finite FLOAT64.
# FLOAT64 has 11 exponent bits, so its maximum value is about 2**(2**10).
# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) overflows for x > 709.78.
_FLOAT64_EXP_BOUND = sge.convert(709.78)

# Registry mapping unary operators to the functions that compile them;
# populated by the ``@UNARY_OP_REGISTRATION.register(...)`` decorators below.
UNARY_OP_REGISTRATION = OpRegistration()


def compile(op: ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile a unary operator applied to ``expr`` into a SQLGlot expression.

    Args:
        op: The unary operator to compile; used as the lookup key in
            ``UNARY_OP_REGISTRATION``.
        expr: The typed operand expression.

    Returns:
        Whatever the implementation registered for ``op`` returns when called
        with ``(op, expr)``.
    """
    compile_impl = UNARY_OP_REGISTRATION[op]
    return compile_impl(op, expr)


@UNARY_OP_REGISTRATION.register(ops.arccos_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile arccos to a ``CASE`` that yields NaN outside ``[-1, 1]``.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        ``CASE WHEN ABS(x) > 1 THEN <NaN> ELSE ACOS(x) END``.
    """
    operand = expr.expr
    # Inputs with magnitude above 1 are mapped to NaN rather than passed to ACOS.
    out_of_domain = sge.If(
        this=sge.func("ABS", operand) > sge.convert(1),
        true=_NAN,
    )
    in_domain = sge.func("ACOS", operand)
    return sge.Case(ifs=[out_of_domain], default=in_domain)


@UNARY_OP_REGISTRATION.register(ops.arcsin_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile arcsin to a ``CASE`` that yields NaN outside ``[-1, 1]``.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        ``CASE WHEN ABS(x) > 1 THEN <NaN> ELSE ASIN(x) END``.
    """
    operand = expr.expr
    # Inputs with magnitude above 1 are mapped to NaN rather than passed to ASIN.
    out_of_domain = sge.If(
        this=sge.func("ABS", operand) > sge.convert(1),
        true=_NAN,
    )
    in_domain = sge.func("ASIN", operand)
    return sge.Case(ifs=[out_of_domain], default=in_domain)


@UNARY_OP_REGISTRATION.register(ops.arctan_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile arctan to an ``ATAN(x)`` function call.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        The ``ATAN`` call applied to the operand.
    """
    operand = expr.expr
    return sge.func("ATAN", operand)


@UNARY_OP_REGISTRATION.register(ops.ArrayToStringOp)
def _(op: ops.ArrayToStringOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``ArrayToStringOp`` to an ``ARRAY_TO_STRING`` expression.

    Args:
        op: The operator; its ``delimiter`` is placed between array elements.
        expr: The typed array operand expression.

    Returns:
        An ``sge.ArrayToString`` node joining the operand with the delimiter.
    """
    # Build the delimiter as a real string-literal node instead of wrapping it
    # in quotes by hand: the SQL generator then quotes and escapes it, so a
    # delimiter containing a quote or backslash cannot produce malformed SQL.
    delimiter = sge.convert(op.delimiter)
    return sge.ArrayToString(this=expr.expr, expression=delimiter)
Expand Down Expand Up @@ -72,6 +111,49 @@ def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression:
return sge.array(selected_elements)


@UNARY_OP_REGISTRATION.register(ops.cos_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile cos to a ``COS(x)`` function call.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        The ``COS`` call applied to the operand.
    """
    operand = expr.expr
    return sge.func("COS", operand)


@UNARY_OP_REGISTRATION.register(ops.hash_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile the hash operator to a ``FARM_FINGERPRINT(x)`` function call.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        The ``FARM_FINGERPRINT`` call applied to the operand.
    """
    operand = expr.expr
    return sge.func("FARM_FINGERPRINT", operand)


@UNARY_OP_REGISTRATION.register(ops.isnull_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile isnull to an ``x IS NULL`` predicate.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        An ``sge.Is`` node comparing the operand against ``NULL``.
    """
    null_literal = sge.Null()
    return sge.Is(this=expr.expr, expression=null_literal)


@UNARY_OP_REGISTRATION.register(ops.notnull_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile notnull to a ``NOT x IS NULL`` predicate.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        An ``sge.Not`` node wrapping the ``IS NULL`` test on the operand.
    """
    is_null = sge.Is(this=expr.expr, expression=sge.Null())
    return sge.Not(this=is_null)


@UNARY_OP_REGISTRATION.register(ops.sin_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile sin to a ``SIN(x)`` function call.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        The ``SIN`` call applied to the operand.
    """
    operand = expr.expr
    return sge.func("SIN", operand)


@UNARY_OP_REGISTRATION.register(ops.sinh_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile sinh to a ``CASE`` that saturates to signed infinity.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        ``CASE WHEN ABS(x) > <bound> THEN SIGN(x) * <Infinity> ELSE SINH(x) END``
        where ``<bound>`` is ``_FLOAT64_EXP_BOUND``.
    """
    operand = expr.expr
    # Beyond the EXP bound the result is emitted directly as +/-Infinity
    # (presumably so SINH is never evaluated on inputs that would overflow).
    saturated = sge.If(
        this=sge.func("ABS", operand) > _FLOAT64_EXP_BOUND,
        true=sge.func("SIGN", operand) * _INF,
    )
    in_range = sge.func("SINH", operand)
    return sge.Case(ifs=[saturated], default=in_range)


@UNARY_OP_REGISTRATION.register(ops.tan_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Compile tan to a ``TAN(x)`` function call.

    Args:
        op: The operator being compiled (not used by this implementation).
        expr: The typed operand expression.

    Returns:
        The ``TAN`` call applied to the operand.
    """
    operand = expr.expr
    return sge.func("TAN", operand)


# JSON Ops
@UNARY_OP_REGISTRATION.register(ops.JSONExtract)
def _(op: ops.JSONExtract, expr: TypedExpr) -> sge.Expression:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE WHEN ABS(`bfcol_0`) > 1 THEN CAST('NaN' AS FLOAT64) ELSE ACOS(`bfcol_0`) END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE WHEN ABS(`bfcol_0`) > 1 THEN CAST('NaN' AS FLOAT64) ELSE ASIN(`bfcol_0`) END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
ATAN(`bfcol_0`) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
COS(`bfcol_0`) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
FARM_FINGERPRINT(`bfcol_0`) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `string_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
`bfcol_0` IS NULL AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
NOT `bfcol_0` IS NULL AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
SIN(`bfcol_0`) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN ABS(`bfcol_0`) > 709.78
THEN SIGN(`bfcol_0`) * CAST('Infinity' AS FLOAT64)
ELSE SINH(`bfcol_0`)
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
WITH `bfcte_0` AS (
SELECT
`float64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
TAN(`bfcol_0`) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `float64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,21 @@ def _apply_binary_op(
def test_add_numeric(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL for ``ops.add_op`` applied to ``int64_col`` and itself."""
    column = "int64_col"
    frame = scalar_types_df[[column]]
    generated_sql = _apply_binary_op(frame, ops.add_op, column, column)

    snapshot.assert_match(generated_sql, "out.sql")


def test_add_numeric_w_scalar(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL for ``ops.add_op`` applied to ``int64_col`` and the constant 1."""
    column = "int64_col"
    frame = scalar_types_df[[column]]
    scalar_operand = ex.const(1)
    generated_sql = _apply_binary_op(frame, ops.add_op, column, scalar_operand)

    snapshot.assert_match(generated_sql, "out.sql")


def test_add_string(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL for ``ops.add_op`` applied to ``string_col`` and the constant "a"."""
    column = "string_col"
    frame = scalar_types_df[[column]]
    scalar_operand = ex.const("a")
    generated_sql = _apply_binary_op(frame, ops.add_op, column, scalar_operand)

    snapshot.assert_match(generated_sql, "out.sql")


Expand All @@ -64,4 +67,5 @@ def test_json_set(json_types_df: bpd.DataFrame, snapshot):
sql = _apply_binary_op(
bf_df, ops.JSONSet(json_path="$.a"), "json_col", ex.const(100)
)

snapshot.assert_match(sql, "out.sql")
Loading
0