From e493025eb91e74b8ab4cd31a1cdfc234f4055ac1 Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Thu, 1 May 2025 20:09:49 +0000
Subject: [PATCH 1/2] refactor: add uid generator and encapsulate query as cte

---
 .pre-commit-config.yaml                       |   1 +
 bigframes/core/compile/sqlglot/compiler.py    | 103 +++---
 bigframes/core/compile/sqlglot/sqlglot_ir.py  |  44 ++-
 bigframes/core/guid.py                        |  18 +
 bigframes/core/rewrite/identifiers.py         |  24 +-
 .../core/compile/sqlglot/compiler_session.py  |   7 +-
 .../test_compile_readlocal/out.sql            | 313 +++++++++---------
 .../test_compile_readlocal_w_json_df/out.sql  |   7 +-
 .../test_compile_readlocal_w_lists_df/out.sql |  67 ++--
 .../out.sql                                   |  39 ++-
 10 files changed, 343 insertions(+), 280 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 863a345da1..7e46c73d0d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,6 +20,7 @@ repos:
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
+        exclude: "^tests/unit/core/compile/sqlglot/snapshots"
      - id: check-yaml
 - repo: https://github.com/pycqa/isort
   rev: 5.12.0
diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py
index cb510ce365..47fc30b83f 100644
--- a/bigframes/core/compile/sqlglot/compiler.py
+++ b/bigframes/core/compile/sqlglot/compiler.py
@@ -15,14 +15,13 @@
 
 import dataclasses
 import functools
-import itertools
 import typing
 
 from google.cloud import bigquery
 import pyarrow as pa
 import sqlglot.expressions as sge
 
-from bigframes.core import expression, identifiers, nodes, rewrite
+from bigframes.core import expression, guid, nodes, rewrite
 from bigframes.core.compile import configs
 import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
 import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -33,6 +32,9 @@
 class SQLGlotCompiler:
     """Compiles BigFrame nodes into SQL using SQLGlot."""
 
+    uid_gen: guid.SequentialUIDGenerator
+    """Generator for unique identifiers."""
+
     def compile(
         self,
         node: nodes.BigFrameNode,
@@ -82,8 +84,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
             result_node = typing.cast(
                 nodes.ResultNode, rewrite.column_pruning(result_node)
             )
-            result_node = _remap_variables(result_node)
-            sql = self._compile_result_node(result_node)
+            remap_node, _ = rewrite.remap_variables(result_node, self.uid_gen)
+            sql = self._compile_result_node(typing.cast(nodes.ResultNode, remap_node))
             return configs.CompileResult(
                 sql, result_node.schema.to_bigquery(), result_node.order_by
             )
@@ -92,8 +94,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
 
         result_node = dataclasses.replace(result_node, order_by=None)
         result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
-        result_node = _remap_variables(result_node)
-        sql = self._compile_result_node(result_node)
+        remap_node, _ = rewrite.remap_variables(result_node, self.uid_gen)
+        sql = self._compile_result_node(typing.cast(nodes.ResultNode, remap_node))
         # Return the ordering iff no extra columns are needed to define the row order
         if ordering is not None:
             output_order = (
@@ -107,62 +109,53 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
         )
 
     def _compile_result_node(self, root: nodes.ResultNode) -> str:
-        sqlglot_ir = compile_node(root.child)
+        sqlglot_ir = self.compile_node(root.child)
         # TODO: add order_by, limit, and selections to sqlglot_expr
         return sqlglot_ir.sql
 
+    @functools.lru_cache(maxsize=5000)
+    def compile_node(self, node: nodes.BigFrameNode) -> ir.SQLGlotIR:
+        """Compiles node into SQLGlotIR. Caches result."""
+        return node.reduce_up(
+            lambda node, children: self._compile_node(node, *children)
+        )
 
-def _replace_unsupported_ops(node: nodes.BigFrameNode):
-    node = nodes.bottom_up(node, rewrite.rewrite_slice)
-    node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
-    node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
-    return node
-
-
-def _remap_variables(node: nodes.ResultNode) -> nodes.ResultNode:
-    """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
-
-    def anonymous_column_ids() -> typing.Generator[identifiers.ColumnId, None, None]:
-        for i in itertools.count():
-            yield identifiers.ColumnId(name=f"bfcol_{i}")
-
-    result_node, _ = rewrite.remap_variables(node, anonymous_column_ids())
-    return typing.cast(nodes.ResultNode, result_node)
-
-
-@functools.lru_cache(maxsize=5000)
-def compile_node(node: nodes.BigFrameNode) -> ir.SQLGlotIR:
-    """Compiles node into CompileArrayValue. Caches result."""
-    return node.reduce_up(lambda node, children: _compile_node(node, *children))
-
-
-@functools.singledispatch
-def _compile_node(
-    node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
-) -> ir.SQLGlotIR:
-    """Defines transformation but isn't cached, always use compile_node instead"""
-    raise ValueError(f"Can't compile unrecognized node: {node}")
+    @functools.singledispatchmethod
+    def _compile_node(
+        self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        """Defines transformation but isn't cached; always use compile_node instead."""
+        raise ValueError(f"Can't compile unrecognized node: {node}")
+
+    @_compile_node.register
+    def compile_readlocal(self, node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
+        pa_table = node.local_data_source.data
+        pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
+        pa_table = pa_table.rename_columns(
+            [item.id.sql for item in node.scan_list.items]
+        )
 
+        offsets = node.offsets_col.sql if node.offsets_col else None
+        if offsets:
+            pa_table = pa_table.append_column(
+                offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
+            )
 
-@_compile_node.register
-def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
-    pa_table = node.local_data_source.data
-    pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
-    pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
+        return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=self.uid_gen)
 
-    offsets = node.offsets_col.sql if node.offsets_col else None
-    if offsets:
-        pa_table = pa_table.append_column(
-            offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
+    @_compile_node.register
+    def compile_selection(
+        self, node: nodes.SelectionNode, child: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        selected_cols: tuple[tuple[str, sge.Expression], ...] 
= tuple( + (id.sql, scalar_compiler.compile_scalar_expression(expr)) + for expr, id in node.input_output_pairs ) - - return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema) + return child.select(selected_cols) -@_compile_node.register -def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: - select_cols: typing.Dict[str, sge.Expression] = { - id.name: scalar_compiler.compile_scalar_expression(expr) - for expr, id in node.input_output_pairs - } - return child.select(select_cols) +def _replace_unsupported_ops(node: nodes.BigFrameNode): + node = nodes.bottom_up(node, rewrite.rewrite_slice) + node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions) + node = nodes.bottom_up(node, rewrite.rewrite_range_rolling) + return node diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 607e712a2b..24eef41fda 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -23,6 +23,7 @@ import sqlglot.expressions as sge from bigframes import dtypes +from bigframes.core import guid import bigframes.core.compile.sqlglot.sqlglot_types as sgt import bigframes.core.local_data as local_data import bigframes.core.schema as schemata @@ -52,6 +53,9 @@ class SQLGlotIR: pretty: bool = True """Whether to pretty-print the generated SQL.""" + uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator() + """Generator for unique identifiers.""" + @property def sql(self) -> str: """Generate SQL string from the given expression.""" @@ -59,7 +63,10 @@ def sql(self) -> str: @classmethod def from_pyarrow( - cls, pa_table: pa.Table, schema: schemata.ArraySchema + cls, + pa_table: pa.Table, + schema: schemata.ArraySchema, + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Builds SQLGlot expression from pyarrow table.""" dtype_expr = sge.DataType( @@ -95,21 +102,44 @@ def from_pyarrow( ), ], ) - return cls(expr=sg.select(sge.Star()).from_(expr)) + return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen) def select( self, - select_cols: typing.Dict[str, sge.Expression], + selected_cols: tuple[tuple[str, sge.Expression], ...], ) -> SQLGlotIR: - selected_cols = [ + cols_expr = [ sge.Alias( this=expr, alias=sge.to_identifier(id, quoted=self.quoted), ) - for id, expr in select_cols.items() + for id, expr in selected_cols ] - expr = self.expr.select(*selected_cols, append=False) - return SQLGlotIR(expr=expr) + new_expr = self._encapsulate_as_cte().select(*cols_expr, append=False) + return SQLGlotIR(expr=new_expr) + + def _encapsulate_as_cte( + self, + ) -> sge.Select: + """Transforms a given sge.Select query by pushing its main SELECT statement + into a new CTE and then generates a 'SELECT * FROM new_cte_name' + for the new query.""" + select_expr = self.expr.copy() + + existing_ctes = select_expr.args.pop("with", []) + new_cte_name = sge.to_identifier( + self.uid_gen.generate_sequential_uid("bfcte_"), quoted=self.quoted + ) + new_cte = sge.CTE( + this=select_expr, + alias=new_cte_name, + ) + new_with_clause = sge.With(expressions=existing_ctes + [new_cte]) + new_select_expr = ( + sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name)) + ) + new_select_expr.set("with", new_with_clause) + return new_select_expr def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: diff --git a/bigframes/core/guid.py b/bigframes/core/guid.py index 8930d0760a..cb3094c0e2 100644 --- a/bigframes/core/guid.py +++ b/bigframes/core/guid.py @@ -19,3 +19,21 
@@ def generate_guid(prefix="col_"): global _GUID_COUNTER _GUID_COUNTER += 1 return f"bfuid_{prefix}{_GUID_COUNTER}" + + +class SequentialUIDGenerator: + """ + Generates sequential-like UIDs with multiple prefixes, e.g., "t0", "t1", "c0", "t2", etc. + """ + + def __init__(self): + self.prefix_counters = {} + + def generate_sequential_uid(self, prefix: str) -> str: + """Generates a sequential UID with specified prefix.""" + if prefix not in self.prefix_counters: + self.prefix_counters[prefix] = 0 + + uid = f"{prefix}{self.prefix_counters[prefix]}" + self.prefix_counters[prefix] += 1 + return uid diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py index d49e5c1b42..e09ef2e519 100644 --- a/bigframes/core/rewrite/identifiers.py +++ b/bigframes/core/rewrite/identifiers.py @@ -13,22 +13,17 @@ # limitations under the License. from __future__ import annotations -from typing import Generator, Tuple +from typing import Tuple -import bigframes.core.identifiers -import bigframes.core.nodes +from bigframes.core import guid, identifiers, nodes # TODO: May as well just outright remove selection nodes in this process. def remap_variables( - root: bigframes.core.nodes.BigFrameNode, - id_generator: Generator[bigframes.core.identifiers.ColumnId, None, None], -) -> Tuple[ - bigframes.core.nodes.BigFrameNode, - dict[bigframes.core.identifiers.ColumnId, bigframes.core.identifiers.ColumnId], -]: - """ - Remap all variables in the BFET using the id_generator. + root: nodes.BigFrameNode, + uid_gen: guid.SequentialUIDGenerator, +) -> Tuple[nodes.BigFrameNode, dict[identifiers.ColumnId, identifiers.ColumnId],]: + """Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs. Note: this will convert a DAG to a tree. """ @@ -36,7 +31,7 @@ def remap_variables( ref_mapping = dict() # Sequential ids are assigned bottom-up left-to-right for child in root.child_nodes: - new_child, child_var_mapping = remap_variables(child, id_generator=id_generator) + new_child, child_var_mapping = remap_variables(child, uid_gen=uid_gen) child_replacement_map[child] = new_child ref_mapping.update(child_var_mapping) @@ -47,7 +42,10 @@ def remap_variables( with_new_refs = with_new_children.remap_refs(ref_mapping) - node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids} + node_var_mapping = { + old_id: identifiers.ColumnId(name=uid_gen.generate_sequential_uid("bfcol_")) + for old_id in root.node_defined_ids + } with_new_vars = with_new_refs.remap_vars(node_var_mapping) with_new_vars._validate() diff --git a/tests/unit/core/compile/sqlglot/compiler_session.py b/tests/unit/core/compile/sqlglot/compiler_session.py index eddae8f891..67896e2e41 100644 --- a/tests/unit/core/compile/sqlglot/compiler_session.py +++ b/tests/unit/core/compile/sqlglot/compiler_session.py @@ -18,6 +18,7 @@ import bigframes.core import bigframes.core.compile.sqlglot as sqlglot +import bigframes.core.guid import bigframes.dataframe import bigframes.session.executor import bigframes.session.metrics @@ -27,7 +28,7 @@ class SQLCompilerExecutor(bigframes.session.executor.Executor): """Executor for SQL compilation using sqlglot.""" - compiler = sqlglot.SQLGlotCompiler() + compiler = sqlglot def to_sql( self, @@ -41,7 +42,9 @@ def to_sql( # Compared with BigQueryCachingExecutor, SQLCompilerExecutor skips # caching the subtree. 
-        return self.compiler.compile(array_value.node, ordered=ordered)
+        return self.compiler.SQLGlotCompiler(
+            uid_gen=bigframes.core.guid.SequentialUIDGenerator()
+        ).compile(array_value.node, ordered=ordered)
 
 
 class SQLCompilerSession(bigframes.session.Session):
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql
index 0ef80dc8b0..f04f9ed023 100644
--- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql
@@ -1,3 +1,161 @@
+WITH `bfcte_0` AS (
+  SELECT
+    *
+  FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` BOOLEAN, `bfcol_2` BYTES, `bfcol_3` DATE, `bfcol_4` DATETIME, `bfcol_5` GEOGRAPHY, `bfcol_6` INT64, `bfcol_7` INT64, `bfcol_8` NUMERIC, `bfcol_9` FLOAT64, `bfcol_10` INT64, `bfcol_11` INT64, `bfcol_12` STRING, `bfcol_13` TIME, `bfcol_14` TIMESTAMP, `bfcol_15` INT64>>[STRUCT(
+    0,
+    TRUE,
+    CAST(b'Hello, World!' AS BYTES),
+    CAST('2021-07-21' AS DATE),
+    CAST('2021-07-21T11:39:45' AS DATETIME),
+    ST_GEOGFROMTEXT('POINT (-122.0838511 37.3860517)'),
+    123456789,
+    0,
+    1.234567890,
+    1.25,
+    0,
+    0,
+    'Hello, World!',
+    CAST('11:41:43.076160' AS TIME),
+    CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP),
+    0
+  ), STRUCT(
+    1,
+    FALSE,
+    CAST(b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf' AS BYTES),
+    CAST('1991-02-03' AS DATE),
+    CAST('1991-01-02T03:45:06' AS DATETIME),
+    ST_GEOGFROMTEXT('POINT (-71.104 42.315)'),
+    -987654321,
+    1,
+    1.234567890,
+    2.51,
+    1,
+    1,
+    'こんにちは',
+    CAST('11:14:34.701606' AS TIME),
+    CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP),
+    1
+  ), STRUCT(
+    2,
+    TRUE,
+    CAST(b'\xc2\xa1Hola Mundo!' AS BYTES),
+    CAST('2023-03-01' AS DATE),
+    CAST('2023-03-01T10:55:13' AS DATETIME),
+    ST_GEOGFROMTEXT('POINT (-0.124474760143016 51.5007826749545)'),
+    314159,
+    0,
+    101.101010100,
+    25000000000.0,
+    2,
+    2,
+    ' ¡Hola Mundo! ',
+    CAST('23:59:59.999999' AS TIME),
+    CAST('2023-03-01T10:55:13.250125+00:00' AS TIMESTAMP),
+    2
+  ), STRUCT(
+    3,
+    CAST(NULL AS BOOLEAN),
+    CAST(NULL AS BYTES),
+    CAST(NULL AS DATE),
+    CAST(NULL AS DATETIME),
+    CAST(NULL AS GEOGRAPHY),
+    CAST(NULL AS INT64),
+    1,
+    CAST(NULL AS NUMERIC),
+    CAST(NULL AS FLOAT64),
+    3,
+    3,
+    CAST(NULL AS STRING),
+    CAST(NULL AS TIME),
+    CAST(NULL AS TIMESTAMP),
+    3
+  ), STRUCT(
+    4,
+    FALSE,
+    CAST(b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf' AS BYTES),
+    CAST('2021-07-21' AS DATE),
+    CAST(NULL AS DATETIME),
+    CAST(NULL AS GEOGRAPHY),
+    -234892,
+    -2345,
+    CAST(NULL AS NUMERIC),
+    CAST(NULL AS FLOAT64),
+    4,
+    4,
+    'Hello, World!',
+    CAST(NULL AS TIME),
+    CAST(NULL AS TIMESTAMP),
+    4
+  ), STRUCT(
+    5,
+    FALSE,
+    CAST(b'G\xc3\xbcten Tag' AS BYTES),
+    CAST('1980-03-14' AS DATE),
+    CAST('1980-03-14T15:16:17' AS DATETIME),
+    CAST(NULL AS GEOGRAPHY),
+    55555,
+    0,
+    5.555555000,
+    555.555,
+    5,
+    5,
+    'Güten Tag!',
+    CAST('15:16:17.181921' AS TIME),
+    CAST('1980-03-14T15:16:17.181921+00:00' AS TIMESTAMP),
+    5
+  ), STRUCT(
+    6,
+    TRUE,
+    CAST(b'Hello\tBigFrames!\x07' AS BYTES),
+    CAST('2023-05-23' AS DATE),
+    CAST('2023-05-23T11:37:01' AS DATETIME),
+    ST_GEOGFROMTEXT('LINESTRING (-0.127959 51.507728, -0.127026 51.507473)'),
+    101202303,
+    2,
+    -10.090807000,
+    -123.456,
+    6,
+    6,
+    'capitalize, This ',
+    CAST('01:02:03.456789' AS TIME),
+    CAST('2023-05-23T11:42:55.000001+00:00' AS TIMESTAMP),
+    6
+  ), STRUCT(
+    7,
+    TRUE,
+    CAST(NULL AS BYTES),
+    CAST('2038-01-20' AS DATE),
+    CAST('2038-01-19T03:14:08' AS DATETIME),
+    CAST(NULL AS GEOGRAPHY),
+    -214748367,
+    2,
+    11111111.100000000,
+    42.42,
+    7,
+    7,
+    ' سلام',
+    CAST('12:00:00.000001' AS TIME),
+    CAST('2038-01-19T03:14:17.999999+00:00' AS TIMESTAMP),
+    7
+  ), STRUCT(
+    8,
+    FALSE,
+    CAST(NULL AS BYTES),
+    CAST(NULL AS DATE),
+    CAST(NULL AS DATETIME),
+    CAST(NULL AS GEOGRAPHY),
+    2,
+    1,
+    CAST(NULL AS NUMERIC),
+    6.87,
+    8,
+    8,
+    'T',
+    CAST(NULL AS TIME),
+    CAST(NULL AS TIMESTAMP),
+    8
+  )])
+)
 SELECT
   `bfcol_0` AS `bfcol_16`,
   `bfcol_1` AS `bfcol_17`,
@@ -15,157 +173,4 @@
   `bfcol_13` AS `bfcol_29`,
   `bfcol_14` AS `bfcol_30`,
   `bfcol_15` AS `bfcol_31`
-FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` BOOLEAN, `bfcol_2` BYTES, `bfcol_3` DATE, `bfcol_4` DATETIME, `bfcol_5` GEOGRAPHY, `bfcol_6` INT64, `bfcol_7` INT64, `bfcol_8` NUMERIC, `bfcol_9` FLOAT64, `bfcol_10` INT64, `bfcol_11` INT64, `bfcol_12` STRING, `bfcol_13` TIME, `bfcol_14` TIMESTAMP, `bfcol_15` INT64>>[STRUCT(
-  0,
-  TRUE,
-  CAST(b'Hello, World!' AS BYTES),
-  CAST('2021-07-21' AS DATE),
-  CAST('2021-07-21T11:39:45' AS DATETIME),
-  ST_GEOGFROMTEXT('POINT (-122.0838511 37.3860517)'),
-  123456789,
-  0,
-  1.234567890,
-  1.25,
-  0,
-  0,
-  'Hello, World!',
-  CAST('11:41:43.076160' AS TIME),
-  CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP),
-  0
-), STRUCT(
-  1,
-  FALSE,
-  CAST(b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf' AS BYTES),
-  CAST('1991-02-03' AS DATE),
-  CAST('1991-01-02T03:45:06' AS DATETIME),
-  ST_GEOGFROMTEXT('POINT (-71.104 42.315)'),
-  -987654321,
-  1,
-  1.234567890,
-  2.51,
-  1,
-  1,
-  'こんにちは',
-  CAST('11:14:34.701606' AS TIME),
-  CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP),
-  1
-), STRUCT(
-  2,
-  TRUE,
-  CAST(b'\xc2\xa1Hola Mundo!' AS BYTES),
-  CAST('2023-03-01' AS DATE),
-  CAST('2023-03-01T10:55:13' AS DATETIME),
-  ST_GEOGFROMTEXT('POINT (-0.124474760143016 51.5007826749545)'),
-  314159,
-  0,
-  101.101010100,
-  25000000000.0,
-  2,
-  2,
-  ' ¡Hola Mundo! ',
-  CAST('23:59:59.999999' AS TIME),
-  CAST('2023-03-01T10:55:13.250125+00:00' AS TIMESTAMP),
-  2
-), STRUCT(
-  3,
-  CAST(NULL AS BOOLEAN),
-  CAST(NULL AS BYTES),
-  CAST(NULL AS DATE),
-  CAST(NULL AS DATETIME),
-  CAST(NULL AS GEOGRAPHY),
-  CAST(NULL AS INT64),
-  1,
-  CAST(NULL AS NUMERIC),
-  CAST(NULL AS FLOAT64),
-  3,
-  3,
-  CAST(NULL AS STRING),
-  CAST(NULL AS TIME),
-  CAST(NULL AS TIMESTAMP),
-  3
-), STRUCT(
-  4,
-  FALSE,
-  CAST(b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf' AS BYTES),
-  CAST('2021-07-21' AS DATE),
-  CAST(NULL AS DATETIME),
-  CAST(NULL AS GEOGRAPHY),
-  -234892,
-  -2345,
-  CAST(NULL AS NUMERIC),
-  CAST(NULL AS FLOAT64),
-  4,
-  4,
-  'Hello, World!',
-  CAST(NULL AS TIME),
-  CAST(NULL AS TIMESTAMP),
-  4
-), STRUCT(
-  5,
-  FALSE,
-  CAST(b'G\xc3\xbcten Tag' AS BYTES),
-  CAST('1980-03-14' AS DATE),
-  CAST('1980-03-14T15:16:17' AS DATETIME),
-  CAST(NULL AS GEOGRAPHY),
-  55555,
-  0,
-  5.555555000,
-  555.555,
-  5,
-  5,
-  'Güten Tag!',
-  CAST('15:16:17.181921' AS TIME),
-  CAST('1980-03-14T15:16:17.181921+00:00' AS TIMESTAMP),
-  5
-), STRUCT(
-  6,
-  TRUE,
-  CAST(b'Hello\tBigFrames!\x07' AS BYTES),
-  CAST('2023-05-23' AS DATE),
-  CAST('2023-05-23T11:37:01' AS DATETIME),
-  ST_GEOGFROMTEXT('LINESTRING (-0.127959 51.507728, -0.127026 51.507473)'),
-  101202303,
-  2,
-  -10.090807000,
-  -123.456,
-  6,
-  6,
-  'capitalize, This ',
-  CAST('01:02:03.456789' AS TIME),
-  CAST('2023-05-23T11:42:55.000001+00:00' AS TIMESTAMP),
-  6
-), STRUCT(
-  7,
-  TRUE,
-  CAST(NULL AS BYTES),
-  CAST('2038-01-20' AS DATE),
-  CAST('2038-01-19T03:14:08' AS DATETIME),
-  CAST(NULL AS GEOGRAPHY),
-  -214748367,
-  2,
-  11111111.100000000,
-  42.42,
-  7,
-  7,
-  ' سلام',
-  CAST('12:00:00.000001' AS TIME),
-  CAST('2038-01-19T03:14:17.999999+00:00' AS TIMESTAMP),
-  7
-), STRUCT(
-  8,
-  FALSE,
-  CAST(NULL AS BYTES),
-  CAST(NULL AS DATE),
-  CAST(NULL AS DATETIME),
-  CAST(NULL AS GEOGRAPHY),
-  2,
-  1,
-  CAST(NULL AS NUMERIC),
-  6.87,
-  8,
-  8,
-  'T',
-  CAST(NULL AS TIME),
-  CAST(NULL AS TIMESTAMP),
-  8
-)])
\ No newline at end of file
+FROM `bfcte_0`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql
index 3b780e6d8e..c0e5a0a476 100644
--- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql
@@ -1,4 +1,9 @@
+WITH `bfcte_0` AS (
+  SELECT
+    *
+  FROM UNNEST(ARRAY<STRUCT<`bfcol_0` JSON, `bfcol_1` INT64>>[STRUCT(PARSE_JSON('null'), 0), STRUCT(PARSE_JSON('true'), 1), STRUCT(PARSE_JSON('100'), 2), STRUCT(PARSE_JSON('0.98'), 3), STRUCT(PARSE_JSON('"a string"'), 4), STRUCT(PARSE_JSON('[]'), 5), STRUCT(PARSE_JSON('[1,2,3]'), 6), STRUCT(PARSE_JSON('[{"a":1},{"a":2},{"a":null},{}]'), 7), STRUCT(PARSE_JSON('"100"'), 8), STRUCT(PARSE_JSON('{"date":"2024-07-16"}'), 9), STRUCT(PARSE_JSON('{"int_value":2,"null_filed":null}'), 10), STRUCT(PARSE_JSON('{"list_data":[10,20,30]}'), 11)])
+)
 SELECT
   `bfcol_0` AS `bfcol_2`,
   `bfcol_1` AS `bfcol_3`
-FROM UNNEST(ARRAY<STRUCT<`bfcol_0` JSON, `bfcol_1` INT64>>[STRUCT(PARSE_JSON('null'), 0), STRUCT(PARSE_JSON('true'), 1), STRUCT(PARSE_JSON('100'), 2), STRUCT(PARSE_JSON('0.98'), 3), STRUCT(PARSE_JSON('"a string"'), 4), STRUCT(PARSE_JSON('[]'), 5), STRUCT(PARSE_JSON('[1,2,3]'), 6), STRUCT(PARSE_JSON('[{"a":1},{"a":2},{"a":null},{}]'), 7), STRUCT(PARSE_JSON('"100"'), 8), STRUCT(PARSE_JSON('{"date":"2024-07-16"}'), 9), STRUCT(PARSE_JSON('{"int_value":2,"null_filed":null}'), 10), STRUCT(PARSE_JSON('{"list_data":[10,20,30]}'), 11)])
\ No newline at end of file
+FROM `bfcte_0`
\ No newline at end of file
STRUCT(PARSE_JSON('{"date":"2024-07-16"}'), 9), STRUCT(PARSE_JSON('{"int_value":2,"null_filed":null}'), 10), STRUCT(PARSE_JSON('{"list_data":[10,20,30]}'), 11)]) \ No newline at end of file +FROM `bfcte_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql index 6998b41b27..c97babdaef 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql @@ -1,3 +1,38 @@ +WITH `bfcte_0` AS ( + SELECT + * + FROM UNNEST(ARRAY, `bfcol_2` ARRAY, `bfcol_3` ARRAY, `bfcol_4` ARRAY, `bfcol_5` ARRAY, `bfcol_6` ARRAY, `bfcol_7` ARRAY, `bfcol_8` INT64>>[STRUCT( + 0, + [1], + [TRUE], + [1.2, 2.3], + ['2021-07-21'], + ['2021-07-21 11:39:45'], + [1.2, 2.3, 3.4], + ['abc', 'de', 'f'], + 0 + ), STRUCT( + 1, + [1, 2], + [TRUE, FALSE], + [1.1], + ['2021-07-21', '1987-03-28'], + ['1999-03-14 17:22:00'], + [5.5, 2.3], + ['a', 'bc', 'de'], + 1 + ), STRUCT( + 2, + [1, 2, 3], + [TRUE], + [0.5, -1.9, 2.3], + ['2017-08-01', '2004-11-22'], + ['1979-06-03 03:20:45'], + [1.7000000000000002], + ['', 'a'], + 2 + )]) +) SELECT `bfcol_0` AS `bfcol_9`, `bfcol_1` AS `bfcol_10`, @@ -8,34 +43,4 @@ SELECT `bfcol_6` AS `bfcol_15`, `bfcol_7` AS `bfcol_16`, `bfcol_8` AS `bfcol_17` -FROM UNNEST(ARRAY, `bfcol_2` ARRAY, `bfcol_3` ARRAY, `bfcol_4` ARRAY, `bfcol_5` ARRAY, `bfcol_6` ARRAY, `bfcol_7` ARRAY, `bfcol_8` INT64>>[STRUCT( - 0, - [1], - [TRUE], - [1.2, 2.3], - ['2021-07-21'], - ['2021-07-21 11:39:45'], - [1.2, 2.3, 3.4], - ['abc', 'de', 'f'], - 0 -), STRUCT( - 1, - [1, 2], - [TRUE, FALSE], - [1.1], - ['2021-07-21', '1987-03-28'], - ['1999-03-14 17:22:00'], - [5.5, 2.3], - ['a', 'bc', 'de'], - 1 -), STRUCT( - 2, - [1, 2, 3], - [TRUE], - [0.5, -1.9, 2.3], - ['2017-08-01', '2004-11-22'], - ['1979-06-03 03:20:45'], - [1.7000000000000002], - ['', 'a'], - 2 -)]) \ No newline at end of file +FROM `bfcte_0` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql index 99b94915bf..509e63e029 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql @@ -1,21 +1,26 @@ +WITH `bfcte_0` AS ( + SELECT + * + FROM UNNEST(ARRAY>, `bfcol_2` INT64>>[STRUCT( + 1, + STRUCT( + 'Alice' AS `name`, + 30 AS `age`, + STRUCT('New York' AS `city`, 'USA' AS `country`) AS `address` + ), + 0 + ), STRUCT( + 2, + STRUCT( + 'Bob' AS `name`, + 25 AS `age`, + STRUCT('London' AS `city`, 'UK' AS `country`) AS `address` + ), + 1 + )]) +) SELECT `bfcol_0` AS `bfcol_3`, `bfcol_1` AS `bfcol_4`, `bfcol_2` AS `bfcol_5` -FROM UNNEST(ARRAY>, `bfcol_2` INT64>>[STRUCT( - 1, - STRUCT( - 'Alice' AS `name`, - 30 AS `age`, - STRUCT('New York' AS `city`, 'USA' AS `country`) AS `address` - ), - 0 -), STRUCT( - 2, - STRUCT( - 'Bob' AS `name`, - 25 AS `age`, - STRUCT('London' AS `city`, 'UK' AS `country`) AS `address` - ), - 1 -)]) \ No newline at end of file +FROM `bfcte_0` \ No newline at end of file From 
From b37e6ac97f2df443d9a122eeebb71fed9a002f96 Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Fri, 2 May 2025 18:26:19 +0000
Subject: [PATCH 2/2] address comments

---
 bigframes/core/compile/sqlglot/compiler.py   | 22 ++++++++++++-----
 bigframes/core/compile/sqlglot/sqlglot_ir.py |  2 +-
 bigframes/core/guid.py                       | 20 ++++++++-------
 bigframes/core/rewrite/identifiers.py        | 18 +++++++--------
 .../core/compile/sqlglot/compiler_session.py |  7 +++---
 5 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py
index 47fc30b83f..f6d63531da 100644
--- a/bigframes/core/compile/sqlglot/compiler.py
+++ b/bigframes/core/compile/sqlglot/compiler.py
@@ -21,20 +21,22 @@
 import pyarrow as pa
 import sqlglot.expressions as sge
 
-from bigframes.core import expression, guid, nodes, rewrite
+from bigframes.core import expression, guid, identifiers, nodes, rewrite
 from bigframes.core.compile import configs
 import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
 import bigframes.core.compile.sqlglot.sqlglot_ir as ir
 import bigframes.core.ordering as bf_ordering
 
 
-@dataclasses.dataclass(frozen=True)
 class SQLGlotCompiler:
     """Compiles BigFrame nodes into SQL using SQLGlot."""
 
     uid_gen: guid.SequentialUIDGenerator
     """Generator for unique identifiers."""
 
+    def __init__(self):
+        self.uid_gen = guid.SequentialUIDGenerator()
+
     def compile(
         self,
         node: nodes.BigFrameNode,
@@ -84,8 +86,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
             result_node = typing.cast(
                 nodes.ResultNode, rewrite.column_pruning(result_node)
             )
-            remap_node, _ = rewrite.remap_variables(result_node, self.uid_gen)
-            sql = self._compile_result_node(typing.cast(nodes.ResultNode, remap_node))
+            result_node = self._remap_variables(result_node)
+            sql = self._compile_result_node(result_node)
             return configs.CompileResult(
                 sql, result_node.schema.to_bigquery(), result_node.order_by
             )
@@ -94,8 +96,8 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
 
         result_node = dataclasses.replace(result_node, order_by=None)
         result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
-        remap_node, _ = rewrite.remap_variables(result_node, self.uid_gen)
-        sql = self._compile_result_node(typing.cast(nodes.ResultNode, remap_node))
+        result_node = self._remap_variables(result_node)
+        sql = self._compile_result_node(result_node)
         # Return the ordering iff no extra columns are needed to define the row order
         if ordering is not None:
             output_order = (
@@ -108,6 +110,14 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
             sql, result_node.schema.to_bigquery(), output_order
         )
 
+    def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
+        """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
+
+        result_node, _ = rewrite.remap_variables(
+            node, map(identifiers.ColumnId, self.uid_gen.get_uid_stream("bfcol_"))
+        )
+        return typing.cast(nodes.ResultNode, result_node)
+
     def _compile_result_node(self, root: nodes.ResultNode) -> str:
         sqlglot_ir = self.compile_node(root.child)
         # TODO: add order_by, limit, and selections to sqlglot_expr
diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index 24eef41fda..660576670d 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -128,7 +128,7 @@ def _encapsulate_as_cte(
         existing_ctes = select_expr.args.pop("with", [])
         new_cte_name = sge.to_identifier(
-            self.uid_gen.generate_sequential_uid("bfcte_"), quoted=self.quoted
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
         )
         new_cte = sge.CTE(
             this=select_expr,
diff --git a/bigframes/core/guid.py b/bigframes/core/guid.py
index cb3094c0e2..eae6f0a79c 100644
--- a/bigframes/core/guid.py
+++ b/bigframes/core/guid.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import typing
+
 _GUID_COUNTER = 0
 
 
@@ -22,18 +24,20 @@ def generate_guid(prefix="col_"):
 
 
 class SequentialUIDGenerator:
-    """
-    Generates sequential-like UIDs with multiple prefixes, e.g., "t0", "t1", "c0", "t2", etc.
+    """Produces a sequence of UIDs, such as {"t0", "t1", "c0", "t2", ...}, by
+    cycling through provided prefixes (e.g., "t" and "c").
+    Note: this class is not thread-safe.
     """
 
     def __init__(self):
-        self.prefix_counters = {}
+        self.prefix_counters: typing.Dict[str, int] = {}
 
-    def generate_sequential_uid(self, prefix: str) -> str:
-        """Generates a sequential UID with specified prefix."""
+    def get_uid_stream(self, prefix: str) -> typing.Generator[str, None, None]:
+        """Yields a continuous stream of raw UID strings for the given prefix."""
         if prefix not in self.prefix_counters:
             self.prefix_counters[prefix] = 0
 
-        uid = f"{prefix}{self.prefix_counters[prefix]}"
-        self.prefix_counters[prefix] += 1
-        return uid
+        while True:
+            uid = f"{prefix}{self.prefix_counters[prefix]}"
+            self.prefix_counters[prefix] += 1
+            yield uid
diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py
index e09ef2e519..0093e183b4 100644
--- a/bigframes/core/rewrite/identifiers.py
+++ b/bigframes/core/rewrite/identifiers.py
@@ -13,16 +13,19 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Tuple
+import typing
 
-from bigframes.core import guid, identifiers, nodes
+from bigframes.core import identifiers, nodes
 
 
 # TODO: May as well just outright remove selection nodes in this process.
 def remap_variables(
     root: nodes.BigFrameNode,
-    uid_gen: guid.SequentialUIDGenerator,
-) -> Tuple[nodes.BigFrameNode, dict[identifiers.ColumnId, identifiers.ColumnId],]:
+    id_generator: typing.Iterator[identifiers.ColumnId],
+) -> typing.Tuple[
+    nodes.BigFrameNode,
+    dict[identifiers.ColumnId, identifiers.ColumnId],
+]:
     """Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs.
 
     Note: this will convert a DAG to a tree.
@@ -31,7 +34,7 @@ def remap_variables( ref_mapping = dict() # Sequential ids are assigned bottom-up left-to-right for child in root.child_nodes: - new_child, child_var_mapping = remap_variables(child, uid_gen=uid_gen) + new_child, child_var_mapping = remap_variables(child, id_generator=id_generator) child_replacement_map[child] = new_child ref_mapping.update(child_var_mapping) @@ -42,10 +45,7 @@ def remap_variables( with_new_refs = with_new_children.remap_refs(ref_mapping) - node_var_mapping = { - old_id: identifiers.ColumnId(name=uid_gen.generate_sequential_uid("bfcol_")) - for old_id in root.node_defined_ids - } + node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids} with_new_vars = with_new_refs.remap_vars(node_var_mapping) with_new_vars._validate() diff --git a/tests/unit/core/compile/sqlglot/compiler_session.py b/tests/unit/core/compile/sqlglot/compiler_session.py index 67896e2e41..7309349681 100644 --- a/tests/unit/core/compile/sqlglot/compiler_session.py +++ b/tests/unit/core/compile/sqlglot/compiler_session.py @@ -18,7 +18,6 @@ import bigframes.core import bigframes.core.compile.sqlglot as sqlglot -import bigframes.core.guid import bigframes.dataframe import bigframes.session.executor import bigframes.session.metrics @@ -42,9 +41,9 @@ def to_sql( # Compared with BigQueryCachingExecutor, SQLCompilerExecutor skips # caching the subtree. - return self.compiler.SQLGlotCompiler( - uid_gen=bigframes.core.guid.SequentialUIDGenerator() - ).compile(array_value.node, ordered=ordered) + return self.compiler.SQLGlotCompiler().compile( + array_value.node, ordered=ordered + ) class SQLCompilerSession(bigframes.session.Session):
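
For readers who want to see the CTE encapsulation in isolation: the sketch below replays the body of `_encapsulate_as_cte` outside the `SQLGlotIR` class, using the same sqlglot calls the patch uses. The input query and the table name `demo` are invented for the example, and the fixed name `bfcte_0` stands in for the value the UID generator would supply:

# Minimal sketch of the CTE-encapsulation step from _encapsulate_as_cte, using
# only sqlglot calls that appear in this patch. Requires `pip install sqlglot`.
import sqlglot
import sqlglot.expressions as sge

select_expr = sqlglot.select("*").from_("demo")  # stands in for SQLGlotIR.expr

# Pop any CTEs already attached to the query (none in this example, so the
# default empty list is returned) and wrap the remaining SELECT in a new CTE.
existing_ctes = select_expr.args.pop("with", [])
new_cte_name = sge.to_identifier("bfcte_0", quoted=True)
new_cte = sge.CTE(this=select_expr, alias=new_cte_name)

# Build the replacement query: SELECT * FROM the new CTE, carrying over the
# old CTEs plus the new one in a single WITH clause.
new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
new_select_expr.set("with", sge.With(expressions=existing_ctes + [new_cte]))

print(new_select_expr.sql(dialect="bigquery"))
# Expected (roughly): WITH `bfcte_0` AS (SELECT * FROM demo) SELECT * FROM `bfcte_0`

Applied recursively as the compiler walks up the node tree, this is what turns the flat `SELECT ... FROM UNNEST(...)` snapshots into the `WITH `bfcte_0` AS (...) SELECT ... FROM `bfcte_0`` form seen in the updated out.sql files, while keeping each intermediate result addressable by a deterministic name.
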