diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index 9c81eda044..98dbed4cdd 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -175,7 +175,7 @@ def from_union(
         ), f"At least two select expressions must be provided, but got {selects}."
 
         existing_ctes: list[sge.CTE] = []
-        union_selects: list[sge.Select] = []
+        union_selects: list[sge.Expression] = []
         for select in selects:
             assert isinstance(
                 select, sge.Select
@@ -204,10 +204,14 @@ def from_union(
             sge.Select().select(*selections).from_(sge.Table(this=new_cte_name))
         )
 
-        union_expr = sg.union(
-            *union_selects,
-            distinct=False,
-            copy=False,
+        union_expr = typing.cast(
+            sge.Select,
+            functools.reduce(
+                lambda x, y: sge.Union(
+                    this=x, expression=y, distinct=False, copy=False
+                ),
+                union_selects,
+            ),
         )
         final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery())
         final_select_expr.set("with", sge.With(expressions=existing_ctes))
diff --git a/tests/system/small/engines/test_concat.py b/tests/system/small/engines/test_concat.py
index e10570fab2..5786cfc419 100644
--- a/tests/system/small/engines/test_concat.py
+++ b/tests/system/small/engines/test_concat.py
@@ -24,7 +24,7 @@
 REFERENCE_ENGINE = polars_executor.PolarsExecutor()
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_concat_self(
     scalars_array_value: array_value.ArrayValue,
     engine,
@@ -34,7 +34,7 @@ def test_engines_concat_self(
     assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_concat_filtered_sorted(
     scalars_array_value: array_value.ArrayValue,
     engine,
diff --git a/tests/system/small/engines/test_filtering.py b/tests/system/small/engines/test_filtering.py
index 9b7cd034b4..817bb4c3f7 100644
--- a/tests/system/small/engines/test_filtering.py
+++ b/tests/system/small/engines/test_filtering.py
@@ -24,7 +24,7 @@
 REFERENCE_ENGINE = polars_executor.PolarsExecutor()
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_filter_bool_col(
     scalars_array_value: array_value.ArrayValue,
     engine,
@@ -35,7 +35,7 @@ def test_engines_filter_bool_col(
     assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_filter_expr_cond(
     scalars_array_value: array_value.ArrayValue,
     engine,
@@ -47,7 +47,7 @@ def test_engines_filter_expr_cond(
     assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_filter_true(
     scalars_array_value: array_value.ArrayValue,
     engine,
@@ -57,7 +57,7 @@ def test_engines_filter_true(
     assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_filter_false(
     scalars_array_value: array_value.ArrayValue,
     engine,
diff --git a/tests/system/small/engines/test_strings.py b/tests/system/small/engines/test_strings.py
index cbab517ef0..d450474504 100644
--- a/tests/system/small/engines/test_strings.py
+++ b/tests/system/small/engines/test_strings.py
@@ -25,7 +25,7 @@
 REFERENCE_ENGINE = polars_executor.PolarsExecutor()
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_str_contains(scalars_array_value: array_value.ArrayValue, engine):
     arr, _ = scalars_array_value.compute_values(
         [
@@ -38,7 +38,7 @@ def test_engines_str_contains(scalars_array_value: array_value.ArrayValue, engin
     assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_str_contains_regex(
     scalars_array_value: array_value.ArrayValue, engine
 ):
@@ -53,7 +53,7 @@ def test_engines_str_contains_regex(
     assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_str_startswith(scalars_array_value: array_value.ArrayValue, engine):
     arr, _ = scalars_array_value.compute_values(
         [
@@ -65,7 +65,7 @@ def test_engines_str_startswith(scalars_array_value: array_value.ArrayValue, eng
     assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
 
 
-@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
 def test_engines_str_endswith(scalars_array_value: array_value.ArrayValue, engine):
     arr, _ = scalars_array_value.compute_values(
         [
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql
index 62e22a6a19..faff452761 100644
--- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql
@@ -1,78 +1,82 @@
 WITH `bfcte_1` AS (
   SELECT
-    *
-  FROM UNNEST(ARRAY>[STRUCT(0, 123456789, 0, 'Hello, World!', 0), STRUCT(1, -987654321, 1, 'こんにちは', 1), STRUCT(2, 314159, 2, ' ¡Hola Mundo! ', 2), STRUCT(3, CAST(NULL AS INT64), 3, CAST(NULL AS STRING), 3), STRUCT(4, -234892, 4, 'Hello, World!', 4), STRUCT(5, 55555, 5, 'Güten Tag!', 5), STRUCT(6, 101202303, 6, 'capitalize, This ', 6), STRUCT(7, -214748367, 7, ' سلام', 7), STRUCT(8, 2, 8, 'T', 8)])
+    `int64_col` AS `bfcol_0`,
+    `rowindex` AS `bfcol_1`,
+    `string_col` AS `bfcol_2`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
 ), `bfcte_3` AS (
   SELECT
     *,
-    `bfcol_4` AS `bfcol_10`
+    ROW_NUMBER() OVER () AS `bfcol_7`
   FROM `bfcte_1`
 ), `bfcte_5` AS (
   SELECT
     *,
-    0 AS `bfcol_16`
+    0 AS `bfcol_8`
   FROM `bfcte_3`
 ), `bfcte_6` AS (
   SELECT
-    `bfcol_0` AS `bfcol_17`,
-    `bfcol_2` AS `bfcol_18`,
-    `bfcol_1` AS `bfcol_19`,
-    `bfcol_3` AS `bfcol_20`,
-    `bfcol_16` AS `bfcol_21`,
-    `bfcol_10` AS `bfcol_22`
+    `bfcol_1` AS `bfcol_9`,
+    `bfcol_1` AS `bfcol_10`,
+    `bfcol_0` AS `bfcol_11`,
+    `bfcol_2` AS `bfcol_12`,
+    `bfcol_8` AS `bfcol_13`,
+    `bfcol_7` AS `bfcol_14`
   FROM `bfcte_5`
 ), `bfcte_0` AS (
   SELECT
-    *
-  FROM UNNEST(ARRAY>[STRUCT(0, 123456789, 0, 'Hello, World!', 0), STRUCT(1, -987654321, 1, 'こんにちは', 1), STRUCT(2, 314159, 2, ' ¡Hola Mundo! ', 2), STRUCT(3, CAST(NULL AS INT64), 3, CAST(NULL AS STRING), 3), STRUCT(4, -234892, 4, 'Hello, World!', 4), STRUCT(5, 55555, 5, 'Güten Tag!', 5), STRUCT(6, 101202303, 6, 'capitalize, This ', 6), STRUCT(7, -214748367, 7, ' سلام', 7), STRUCT(8, 2, 8, 'T', 8)])
+    `int64_col` AS `bfcol_15`,
+    `rowindex` AS `bfcol_16`,
+    `string_col` AS `bfcol_17`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
 ), `bfcte_2` AS (
   SELECT
     *,
-    `bfcol_27` AS `bfcol_33`
+    ROW_NUMBER() OVER () AS `bfcol_22`
   FROM `bfcte_0`
 ), `bfcte_4` AS (
   SELECT
     *,
-    1 AS `bfcol_39`
+    1 AS `bfcol_23`
   FROM `bfcte_2`
 ), `bfcte_7` AS (
   SELECT
-    `bfcol_23` AS `bfcol_40`,
-    `bfcol_25` AS `bfcol_41`,
-    `bfcol_24` AS `bfcol_42`,
-    `bfcol_26` AS `bfcol_43`,
-    `bfcol_39` AS `bfcol_44`,
-    `bfcol_33` AS `bfcol_45`
+    `bfcol_16` AS `bfcol_24`,
+    `bfcol_16` AS `bfcol_25`,
+    `bfcol_15` AS `bfcol_26`,
+    `bfcol_17` AS `bfcol_27`,
+    `bfcol_23` AS `bfcol_28`,
+    `bfcol_22` AS `bfcol_29`
   FROM `bfcte_4`
 ), `bfcte_8` AS (
   SELECT
     *
   FROM (
     SELECT
-      `bfcol_17` AS `bfcol_46`,
-      `bfcol_18` AS `bfcol_47`,
-      `bfcol_19` AS `bfcol_48`,
-      `bfcol_20` AS `bfcol_49`,
-      `bfcol_21` AS `bfcol_50`,
-      `bfcol_22` AS `bfcol_51`
+      `bfcol_9` AS `bfcol_30`,
+      `bfcol_10` AS `bfcol_31`,
+      `bfcol_11` AS `bfcol_32`,
+      `bfcol_12` AS `bfcol_33`,
+      `bfcol_13` AS `bfcol_34`,
+      `bfcol_14` AS `bfcol_35`
     FROM `bfcte_6`
     UNION ALL
     SELECT
-      `bfcol_40` AS `bfcol_46`,
-      `bfcol_41` AS `bfcol_47`,
-      `bfcol_42` AS `bfcol_48`,
-      `bfcol_43` AS `bfcol_49`,
-      `bfcol_44` AS `bfcol_50`,
-      `bfcol_45` AS `bfcol_51`
+      `bfcol_24` AS `bfcol_30`,
+      `bfcol_25` AS `bfcol_31`,
+      `bfcol_26` AS `bfcol_32`,
+      `bfcol_27` AS `bfcol_33`,
+      `bfcol_28` AS `bfcol_34`,
+      `bfcol_29` AS `bfcol_35`
     FROM `bfcte_7`
   )
 )
 SELECT
-  `bfcol_46` AS `rowindex`,
-  `bfcol_47` AS `rowindex_1`,
-  `bfcol_48` AS `int64_col`,
-  `bfcol_49` AS `string_col`
+  `bfcol_30` AS `rowindex`,
+  `bfcol_31` AS `rowindex_1`,
+  `bfcol_32` AS `int64_col`,
+  `bfcol_33` AS `string_col`
 FROM `bfcte_8`
 ORDER BY
-  `bfcol_50` ASC NULLS LAST,
-  `bfcol_51` ASC NULLS LAST
\ No newline at end of file
+  `bfcol_34` ASC NULLS LAST,
+  `bfcol_35` ASC NULLS LAST
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat_filter_sorted/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat_filter_sorted/out.sql
new file mode 100644
index 0000000000..5043435688
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat_filter_sorted/out.sql
@@ -0,0 +1,142 @@
+WITH `bfcte_3` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`,
+    `float64_col` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_7` AS (
+  SELECT
+    *,
+    ROW_NUMBER() OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_4`
+  FROM `bfcte_3`
+), `bfcte_11` AS (
+  SELECT
+    *,
+    0 AS `bfcol_5`
+  FROM `bfcte_7`
+), `bfcte_14` AS (
+  SELECT
+    `bfcol_1` AS `bfcol_6`,
+    `bfcol_0` AS `bfcol_7`,
+    `bfcol_5` AS `bfcol_8`,
+    `bfcol_4` AS `bfcol_9`
+  FROM `bfcte_11`
+), `bfcte_2` AS (
+  SELECT
+    `bool_col` AS `bfcol_10`,
+    `int64_too` AS `bfcol_11`,
+    `float64_col` AS `bfcol_12`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_6` AS (
+  SELECT
+    *
+  FROM `bfcte_2`
+  WHERE
+    `bfcol_10`
+), `bfcte_10` AS (
+  SELECT
+    *,
+    ROW_NUMBER() OVER () AS `bfcol_15`
+  FROM `bfcte_6`
+), `bfcte_13` AS (
+  SELECT
+    *,
+    1 AS `bfcol_16`
+  FROM `bfcte_10`
+), `bfcte_15` AS (
+  SELECT
+    `bfcol_12` AS `bfcol_17`,
+    `bfcol_11` AS `bfcol_18`,
+    `bfcol_16` AS `bfcol_19`,
+    `bfcol_15` AS `bfcol_20`
+  FROM `bfcte_13`
+), `bfcte_1` AS (
+  SELECT
+    `int64_col` AS `bfcol_21`,
+    `float64_col` AS `bfcol_22`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_5` AS (
+  SELECT
+    *,
+    ROW_NUMBER() OVER (ORDER BY `bfcol_21` IS NULL ASC NULLS LAST, `bfcol_21` ASC NULLS LAST) AS `bfcol_25`
+  FROM `bfcte_1`
+), `bfcte_9` AS (
+  SELECT
+    *,
+    2 AS `bfcol_26`
+  FROM `bfcte_5`
+), `bfcte_16` AS (
+  SELECT
+    `bfcol_22` AS `bfcol_27`,
+    `bfcol_21` AS `bfcol_28`,
+    `bfcol_26` AS `bfcol_29`,
+    `bfcol_25` AS `bfcol_30`
+  FROM `bfcte_9`
+), `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_31`,
+    `int64_too` AS `bfcol_32`,
+    `float64_col` AS `bfcol_33`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_4` AS (
+  SELECT
+    *
+  FROM `bfcte_0`
+  WHERE
+    `bfcol_31`
+), `bfcte_8` AS (
+  SELECT
+    *,
+    ROW_NUMBER() OVER () AS `bfcol_36`
+  FROM `bfcte_4`
+), `bfcte_12` AS (
+  SELECT
+    *,
+    3 AS `bfcol_37`
+  FROM `bfcte_8`
+), `bfcte_17` AS (
+  SELECT
+    `bfcol_33` AS `bfcol_38`,
+    `bfcol_32` AS `bfcol_39`,
+    `bfcol_37` AS `bfcol_40`,
+    `bfcol_36` AS `bfcol_41`
+  FROM `bfcte_12`
+), `bfcte_18` AS (
+  SELECT
+    *
+  FROM (
+    SELECT
+      `bfcol_6` AS `bfcol_42`,
+      `bfcol_7` AS `bfcol_43`,
+      `bfcol_8` AS `bfcol_44`,
+      `bfcol_9` AS `bfcol_45`
+    FROM `bfcte_14`
+    UNION ALL
+    SELECT
+      `bfcol_17` AS `bfcol_42`,
+      `bfcol_18` AS `bfcol_43`,
+      `bfcol_19` AS `bfcol_44`,
+      `bfcol_20` AS `bfcol_45`
+    FROM `bfcte_15`
+    UNION ALL
+    SELECT
+      `bfcol_27` AS `bfcol_42`,
+      `bfcol_28` AS `bfcol_43`,
+      `bfcol_29` AS `bfcol_44`,
+      `bfcol_30` AS `bfcol_45`
+    FROM `bfcte_16`
+    UNION ALL
+    SELECT
+      `bfcol_38` AS `bfcol_42`,
+      `bfcol_39` AS `bfcol_43`,
+      `bfcol_40` AS `bfcol_44`,
+      `bfcol_41` AS `bfcol_45`
+    FROM `bfcte_17`
+  )
+)
+SELECT
+  `bfcol_42` AS `float64_col`,
+  `bfcol_43` AS `int64_col`
+FROM `bfcte_18`
+ORDER BY
+  `bfcol_44` ASC NULLS LAST,
+  `bfcol_45` ASC NULLS LAST
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/test_compile_concat.py b/tests/unit/core/compile/sqlglot/test_compile_concat.py
index 79f73d3113..c176b2e116 100644
--- a/tests/unit/core/compile/sqlglot/test_compile_concat.py
+++ b/tests/unit/core/compile/sqlglot/test_compile_concat.py
@@ -12,21 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pandas as pd
 import pytest
 
-import bigframes
+from bigframes.core import ordering
 import bigframes.pandas as bpd
 
 pytest.importorskip("pytest_snapshot")
 
 
-def test_compile_concat(
-    scalar_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session, snapshot
-):
+def test_compile_concat(scalar_types_df: bpd.DataFrame, snapshot):
     # TODO: concat two same dataframes, which SQL does not get reused.
-    # TODO: concat dataframes from a gbq table but trigger a windows compiler.
-    df1 = bpd.DataFrame(scalar_types_pandas_df, session=compiler_session)
-    df1 = df1[["rowindex", "int64_col", "string_col"]]
+    df1 = scalar_types_df[["rowindex", "int64_col", "string_col"]]
     concat_df = bpd.concat([df1, df1])
     snapshot.assert_match(concat_df.sql, "out.sql")
+
+
+def test_compile_concat_filter_sorted(scalar_types_df: bpd.DataFrame, snapshot):
+
+    scalars_array_value = scalar_types_df._block.expr
+    input_1 = scalars_array_value.select_columns(["float64_col", "int64_col"]).order_by(
+        [ordering.ascending_over("int64_col")]
+    )
+    input_2 = scalars_array_value.filter_by_id("bool_col").select_columns(
+        ["float64_col", "int64_too"]
+    )
+
+    result = input_1.concat([input_2, input_1, input_2])
+
+    new_names = ["float64_col", "int64_col"]
+    col_ids = {
+        old_name: new_name for old_name, new_name in zip(result.column_ids, new_names)
+    }
+    result = result.rename_columns(col_ids).select_columns(new_names)
+
+    sql = result.session._executor.to_sql(result, enable_cache=False)
+    snapshot.assert_match(sql, "out.sql")
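
Note (not part of the diff above): the core compiler change in sqlglot_ir.py swaps sqlglot's sg.union() helper for a left fold over explicit sge.Union nodes. Below is a minimal standalone sketch of that pattern; the table and column names (t1, t2, t3, a, b) are made up for illustration, and it assumes sqlglot's select/Union/subquery APIs as they are used in the diff.

    import functools

    import sqlglot as sg
    import sqlglot.expressions as sge

    # Three hypothetical inputs standing in for the per-CTE selects built in from_union.
    selects = [
        sg.select("a", "b").from_("t1"),
        sg.select("a", "b").from_("t2"),
        sg.select("a", "b").from_("t3"),
    ]

    # Left-fold the selects into nested Union nodes: ((t1 UNION ALL t2) UNION ALL t3).
    # distinct=False makes the BigQuery generator emit UNION ALL rather than UNION DISTINCT.
    union_expr = functools.reduce(
        lambda x, y: sge.Union(this=x, expression=y, distinct=False),
        selects,
    )

    # Wrap the union as a subquery and select from it, mirroring the diff's final_select_expr.
    final = sge.Select().select(sge.Star()).from_(union_expr.subquery())
    print(final.sql(dialect="bigquery", pretty=True))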