10BC0 refactor: add compile_random_sample by chelsea-lin · Pull Request #1884 · googleapis/python-bigquery-dataframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ def compile_explode(
columns = tuple(ref.id.sql for ref in node.column_ids)
return child.explode(columns, offsets_col)

@_compile_node.register
def compile_random_sample(
self, node: nodes.RandomSampleNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
return child.sample(node.fraction)


def _replace_unsupported_ops(node: nodes.BigFrameNode):
node = nodes.bottom_up(node, rewrite.rewrite_slice)
Expand Down
192 changes: 129 additions & 63 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes.core import guid
from bigframes.core import guid, utils
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
import bigframes.core.local_data as local_data
Expand Down Expand Up @@ -71,7 +71,10 @@ def from_pyarrow(
schema: bf_schema.ArraySchema,
uid_gen: guid.SequentialUIDGenerator,
) -> SQLGlotIR:
"""Builds SQLGlot expression from pyarrow table."""
"""Builds SQLGlot expression from a pyarrow table.

This is used to represent in-memory data as a SQL query.
"""
dtype_expr = sge.DataType(
this=sge.DataType.Type.STRUCT,
expressions=[
Expand Down Expand Up @@ -117,6 +120,16 @@ def from_table(
alias_names: typing.Sequence[str],
uid_gen: guid.SequentialUIDGenerator,
) -> SQLGlotIR:
"""Builds a SQLGlotIR expression from a BigQuery table.

Args:
project_id (str): The project ID of the BigQuery table.
dataset_id (str): The dataset ID of the BigQuery table.
table_id (str): The table ID of the BigQuery table.
col_names (typing.Sequence[str]): The names of the columns to select.
alias_names (typing.Sequence[str]): The aliases for the selected columns.
uid_gen (guid.SequentialUIDGenerator): A generator for unique identifiers.
"""
selections = [
sge.Alias(
this=sge.to_identifier(col_name, quoted=cls.quoted),
Expand All @@ -137,7 +150,7 @@ def from_query_string(
cls,
query_string: str,
) -> SQLGlotIR:
"""Builds SQLGlot expression from a query string"""
"""Builds a SQLGlot expression from a query string"""
uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
cte_name = sge.to_identifier(
next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
Expand All @@ -157,7 +170,7 @@ def from_union(
output_ids: typing.Sequence[str],
uid_gen: guid.SequentialUIDGenerator,
) -> SQLGlotIR:
"""Builds SQLGlot expression by union of multiple select expressions."""
"""Builds a SQLGlot expression by unioning of multiple select expressions."""
assert (
len(list(selects)) >= 2
), f"At least two select expressions must be provided, but got {selects}."
Expand Down Expand Up @@ -205,6 +218,7 @@ def select(
self,
selected_cols: tuple[tuple[str, sge.Expression], ...],
) -> SQLGlotIR:
"""Replaces new selected columns of the current SELECT clause."""
selections = [
sge.Alias(
this=expr,
Expand All @@ -213,15 +227,41 @@ def select(
for id, expr in selected_cols
]

new_expr, _ = self._encapsulate_as_cte()
new_expr = _select_to_cte(
self.expr,
sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
),
)
new_expr = new_expr.select(*selections, append=False)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def project(
self,
projected_cols: tuple[tuple[str, sge.Expression], ...],
) -> SQLGlotIR:
"""Adds new columns to the SELECT clause."""
projected_cols_expr = [
sge.Alias(
this=expr,
alias=sge.to_identifier(id, quoted=self.quoted),
)
for id, expr in projected_cols
]
new_expr = _select_to_cte(
self.expr,
sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
),
)
new_expr = new_expr.select(*projected_cols_expr, append=True)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def order_by(
self,
ordering: tuple[sge.Ordered, ...],
) -> SQLGlotIR:
"""Adds ORDER BY clause to the query."""
"""Adds an ORDER BY clause to the query."""
if len(ordering) == 0:
return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen)
new_expr = self.expr.order_by(*ordering)
Expand All @@ -231,34 +271,24 @@ def limit(
self,
limit: int | None,
) -> SQLGlotIR:
"""Adds LIMIT clause to the query."""
"""Adds a LIMIT clause to the query."""
if limit is not None:
new_expr = self.expr.limit(limit)
else:
new_expr = self.expr.copy()
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def project(
self,
projected_cols: tuple[tuple[str, sge.Expression], ...],
) -> SQLGlotIR:
projected_cols_expr = [
sge.Alias(
this=expr,
alias=sge.to_identifier(id, quoted=self.quoted),
)
for id, expr in projected_cols
]
new_expr, _ = self._encapsulate_as_cte()
new_expr = new_expr.select(*projected_cols_expr, append=True)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def filter(
self,
condition: sge.Expression,
) -> SQLGlotIR:
"""Filters the query with the given condition."""
new_expr, _ = self._encapsulate_as_cte()
"""Filters the query by adding a WHERE clause."""
new_expr = _select_to_cte(
self.expr,
sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
),
)
return SQLGlotIR(
expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
)
Expand All @@ -272,8 +302,15 @@ def join(
joins_nulls: bool = True,
) -> SQLGlotIR:
"""Joins the current query with another SQLGlotIR instance."""
left_select, left_table = self._encapsulate_as_cte()
right_select, right_table = right._encapsulate_as_cte()
left_cte_name = sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
)
right_cte_name = sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
)

left_select = _select_to_cte(self.expr, left_cte_name)
right_select = _select_to_cte(right.expr, right_cte_name)

left_ctes = left_select.args.pop("with", [])
right_ctes = right_select.args.pop("with", [])
Expand All @@ -288,17 +325,50 @@ def join(
new_expr = (
sge.Select()
.select(sge.Star())
.from_(left_table)
.join(right_table, on=join_on, join_type=join_type_str)
.from_(sge.Table(this=left_cte_name))
.join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
)
new_expr.set("with", sge.With(expressions=merged_ctes))

return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def explode(
self,
column_names: tuple[str, ...],
offsets_col: typing.Optional[str],
) -> SQLGlotIR:
"""Unnests one or more array columns."""
num_columns = len(list(column_names))
assert num_columns > 0, "At least one column must be provided for explode."
if num_columns == 1:
return self._explode_single_column(column_names[0], offsets_col)
else:
return self._explode_multiple_columns(column_names, offsets_col)

def sample(self, fraction: float) -> SQLGlotIR:
"""Uniform samples a fraction of the rows."""
uuid_col = sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted
)
uuid_expr = sge.Alias(this=sge.func("RAND"), alias=uuid_col)
condition = sge.LT(
this=uuid_col,
expression=_literal(fraction, dtypes.FLOAT_DTYPE),
)

new_cte_name = sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
)
new_expr = _select_to_cte(
self.expr.select(uuid_expr, append=True), new_cte_name
).where(condition, append=False)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def insert(
self,
destination: bigquery.TableReference,
) -> str:
"""Generates an INSERT INTO SQL statement from the current SELECT clause."""
return sge.insert(self.expr.subquery(), _table(destination)).sql(
dialect=self.dialect, pretty=self.pretty
)
Expand All @@ -307,6 +377,9 @@ def replace(
self,
A258 destination: bigquery.TableReference,
) -> str:
"""Generates a MERGE statement to replace the destination table's contents.
by the current SELECT clause.
"""
# Workaround for SQLGlot breaking change:
# https://github.com/tobymao/sqlglot/pull/4495
whens_expr = [
Expand All @@ -325,23 +398,10 @@ def replace(
).sql(dialect=self.dialect, pretty=self.pretty)
return f"{merge_str}\n{whens_str}"

def explode(
self,
column_names: tuple[str, ...],
offsets_col: typing.Optional[str],
) -> SQLGlotIR:
num_columns = len(list(column_names))
assert num_columns > 0, "At least one column must be provided for explode."
if num_columns == 1:
return self._explode_single_column(column_names[0], offsets_col)
else:
return self._explode_multiple_columns(column_names, offsets_col)

def _explode_single_column(
self, column_name: str, offsets_col: typing.Optional[str]
) -> SQLGlotIR:
"""Helper method to handle the case of exploding a single column."""

offset = (
sge.to_identifier(offsets_col, quoted=self.quoted) if offsets_col else None
)
Expand All @@ -358,7 +418,12 @@ def _explode_single_column(

# TODO: "CROSS" if not keep_empty else "LEFT"
# TODO: overlaps_with_parent to replace existing column.
new_expr, _ = self._encapsulate_as_cte()
new_expr = _select_to_cte(
self.expr,
sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
),
)
new_expr = new_expr.select(selection, append=False).join(
unnest_expr, join_type="CROSS"
)
Expand Down Expand Up @@ -408,33 +473,32 @@ def _explode_multiple_columns(
for column in columns
]
)
new_expr, _ = self._encapsulate_as_cte()
new_expr = _select_to_cte(
self.expr,
sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
),
)
new_expr = new_expr.select(selection, append=False).join(
unnest_expr, join_type="CROSS"
)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def _encapsulate_as_cte(
self,
) -> typing.Tuple[sge.Select, sge.Table]:
"""Transforms a given sge.Select query by pushing its main SELECT statement
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
for the new query."""
select_expr = self.expr.copy()

existing_ctes = select_expr.args.pop("with", [])
new_cte_name = sge.to_identifier(
next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
)
new_cte = sge.CTE(
this=select_expr,
alias=new_cte_name,
)
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
new_table_expr = sge.Table(this=new_cte_name)
new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr)
new_select_expr.set("with", new_with_clause)
return new_select_expr, new_table_expr
def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
"""Transforms a given sge.Select query by pushing its main SELECT statement
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
for the new query."""
select_expr = expr.copy()
existing_ctes = select_expr.args.pop("with", [])
new_cte = sge.CTE(
this=select_expr,
alias=cte_name,
)
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
new_select_expr.set("with", new_with_clause)
return new_select_expr


def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
Expand All @@ -454,6 +518,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
elif dtype == dtypes.JSON_DTYPE:
return sge.ParseJSON(this=sge.convert(str(value)))
elif dtype == dtypes.TIMEDELTA_DTYPE:
return sge.convert(utils.timedelta_to_micros(value))
elif dtypes.is_struct_like(dtype):
items = [
_literal(value=value[field_name], dtype=field_dtype).as_(
Expand Down
2 changes: 2 additions & 0 deletions bigframes/core/compile/sqlglot/sqlglot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def from_bigframes_dtype(
return "JSON"
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
return "GEOGRAPHY"
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
return "INT64"
elif isinstance(bigframes_dtype, pd.ArrowDtype):
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/core/compile/sqlglot/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from bigframes import dtypes
import bigframes.core as core
import bigframes.pandas as bpd
import bigframes.testing.mocks as mocks
import bigframes.testing.utils
Expand Down Expand Up @@ -115,6 +116,16 @@ def scalar_types_pandas_df() -> pd.DataFrame:
return df


@pytest.fixture(scope="module")
def scalar_types_array_value(
scalar_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session
) -> core.ArrayValue:
managed_data_source = core.local_data.ManagedArrowTable.from_pandas(
scalar_types_pandas_df
)
return core.ArrayValue.from_managed(managed_data_source, compiler_session)


@pytest.fixture(scope="session")
def nested_structs_types_table_schema() -> typing.Sequence[bigquery.SchemaField]:
return [
Expand Down
Loading
0