refactor: Make window op node support non-unary ops by TrevorBergeron · Pull Request #1295 · googleapis/python-bigquery-dataframes · GitHub
Merged
3 changes: 1 addition & 2 deletions bigframes/core/__init__.py
@@ -405,8 +405,7 @@ def project_window_op(
         ArrayValue(
             nodes.WindowOpNode(
                 child=self.node,
-                column_name=ex.deref(column_name),
-                op=op,
+                expression=ex.UnaryAggregation(op, ex.deref(column_name)),
                 window_spec=window_spec,
                 output_name=ids.ColumnId(output_name),
                 never_skip_nulls=never_skip_nulls,
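The net effect in this file: callers of ArrayValue.project_window_op still pass a unary op plus a single column id, and the wrapping into a general Aggregation expression now happens here rather than inside the node. A minimal sketch of that wrapping (the op and column id below are illustrative, not from this diff):

import bigframes.core.expression as ex
import bigframes.operations.aggregations as agg_ops

op = agg_ops.sum_op          # any unary window/aggregate op
column_name = "col_a"        # id of the input column

# Before this PR, WindowOpNode stored column_name and op as separate fields;
# now they are folded into one expression object, so the same node shape can
# also carry nullary and binary aggregations.
expression = ex.UnaryAggregation(op, ex.deref(column_name))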
3 changes: 1 addition & 2 deletions bigframes/core/compile/aggregate_compiler.py
@@ -479,10 +479,9 @@ def _(
     return _apply_window_if_present(column.dense_rank(), window) + 1


-@compile_unary_agg.register
+@compile_nullary_agg.register
 def _(
     op: agg_ops.RowNumberOp,
-    column: ibis_types.Column,
     window=None,
 ) -> ibis_types.IntegerValue:
     return _apply_window_if_present(ibis_api.row_number(), window)
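Since row_number consumes no input column, its compiler registration moves from the unary to the nullary registry. A self-contained sketch of the dispatch pattern assumed here (compile_nullary_agg in the real module is presumably a singledispatch-style registry keyed on the op type; the stand-ins below only mimic that shape):

import functools

@functools.singledispatch
def compile_nullary_agg(op, window=None):
    raise NotImplementedError(f"no nullary compiler registered for {type(op).__name__}")

class RowNumberOp:  # stand-in for agg_ops.RowNumberOp
    pass

@compile_nullary_agg.register
def _(op: RowNumberOp, window=None):
    # The real implementation returns _apply_window_if_present(ibis_api.row_number(), window).
    return "ROW_NUMBER()"

print(compile_nullary_agg(RowNumberOp()))  # ROW_NUMBER()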
48 changes: 30 additions & 18 deletions bigframes/core/compile/compiled.py
@@ -861,8 +861,7 @@ def promote_offsets(self, col_id: str) -> OrderedIR:
     ## Methods that only work with ordering
     def project_window_op(
         self,
-        column_name: ex.DerefOp,
-        op: agg_ops.UnaryWindowOp,
+        expression: ex.Aggregation,
         window_spec: WindowSpec,
         output_name: str,
         *,
@@ -881,53 +880,66 @@
         # See: https://github.com/ibis-project/ibis/issues/9773
         used_exprs = map(
             self._compile_expression,
-            itertools.chain(
-                (column_name,), map(ex.DerefOp, window_spec.all_referenced_columns)
+            map(
+                ex.DerefOp,
+                itertools.chain(
+                    expression.column_references, window_spec.all_referenced_columns
+                ),
             ),
         )
         can_directly_window = not any(
             map(lambda x: is_literal(x) or is_window(x), used_exprs)
         )
         if not can_directly_window:
             return self._reproject_to_table().project_window_op(
-                column_name,
-                op,
+                expression,
                 window_spec,
                 output_name,
                 never_skip_nulls=never_skip_nulls,
             )

-        column = typing.cast(ibis_types.Column, self._compile_expression(column_name))
         window = self._ibis_window_from_spec(
-            window_spec, require_total_order=op.uses_total_row_ordering
+            window_spec, require_total_order=expression.op.uses_total_row_ordering
         )
         bindings = {col: self._get_ibis_column(col) for col in self.column_ids}

         window_op = agg_compiler.compile_analytic(
-            ex.UnaryAggregation(op, column_name),
+            expression,
             window,
             bindings=bindings,
         )

+        inputs = tuple(
+            typing.cast(ibis_types.Column, self._compile_expression(ex.DerefOp(column)))
+            for column in expression.column_references
+        )
         clauses = []
-        if op.skips_nulls and not never_skip_nulls:
-            clauses.append((column.isnull(), ibis_types.null()))
-        if window_spec.min_periods:
-            if op.skips_nulls:
+        if expression.op.skips_nulls and not never_skip_nulls:
+            for column in inputs:
+                clauses.append((column.isnull(), ibis_types.null()))
+        if window_spec.min_periods and len(inputs) > 0:
+            if expression.op.skips_nulls:
                 # Most operations do not count NULL values towards min_periods
+                per_col_does_count = (column.notnull() for column in inputs)
+                # All inputs must be non-null for observation to count
+                is_observation = functools.reduce(
+                    lambda x, y: x & y, per_col_does_count
+                ).cast(int)
                 observation_count = agg_compiler.compile_analytic(
-                    ex.UnaryAggregation(agg_ops.count_op, column_name),
+                    ex.UnaryAggregation(agg_ops.sum_op, ex.deref("_observation_count")),
                     window,
-                    bindings=bindings,
+                    bindings={"_observation_count": is_observation},
                 )
             else:
                 # Operations like count treat even NULLs as valid observations for the sake of min_periods
-                # notnull is just used to convert null values to non-null (FALSE) values to be counted
-                denulled_value = typing.cast(ibis_types.BooleanColumn, column.notnull())
+                is_observation = inputs[0].notnull()
                 observation_count = agg_compiler.compile_analytic(
-                    ex.UnaryAggregation(agg_ops.count_op, ex.deref("_denulled")),
+                    ex.UnaryAggregation(
+                        agg_ops.count_op, ex.deref("_observation_count")
+                    ),
                     window,
-                    bindings={**bindings, "_denulled": denulled_value},
+                    bindings={"_observation_count": is_observation},
                )
         clauses.append(
             (
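The interesting change here is the min_periods bookkeeping for multi-input operations: for a null-skipping op, a row only counts as an observation when every input is non-null, so the observation count becomes a windowed sum over an all-inputs-non-null indicator. A semantics-only sketch in pandas (not the ibis implementation above; the frame and window size are made up):

import pandas as pd

df = pd.DataFrame({"a": [1, None, 3, 4], "b": [10, 20, None, 40]})

# Mirrors functools.reduce(lambda x, y: x & y, per_col_does_count):
# an observation requires all inputs to be non-null.
is_observation = (df["a"].notna() & df["b"].notna()).astype(int)

# The windowed sum of the indicator plays the role of observation_count.
observation_count = is_observation.rolling(window=2, min_periods=1).sum()
print(observation_count.tolist())  # [1.0, 1.0, 0.0, 1.0]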
3 changes: 1 addition & 2 deletions bigframes/core/compile/compiler.py
@@ -364,8 +364,7 @@ def compile_aggregate(self, node: nodes.AggregateNode, ordered: bool = True):
     @_compile_node.register
     def compile_window(self, node: nodes.WindowOpNode, ordered: bool = True):
         result = self.compile_ordered_ir(node.child).project_window_op(
-            node.column_name,
-            node.op,
+            node.expression,
             node.window_spec,
             node.output_name.sql,
             never_skip_nulls=node.never_skip_nulls,
26 changes: 22 additions & 4 deletions bigframes/core/compile/polars/compiler.py
@@ -16,7 +16,7 @@
 import dataclasses
 import functools
 import itertools
-from typing import cast, Sequence, TYPE_CHECKING
+from typing import cast, Sequence, Tuple, TYPE_CHECKING

 import bigframes.core
 import bigframes.core.expression as ex
@@ -125,6 +125,24 @@ def get_args(
             f"Aggregation {agg} not yet supported in polars engine."
         )

+    def compile_agg_expr(self, expr: ex.Aggregation):
+        if isinstance(expr, ex.NullaryAggregation):
+            inputs: Tuple = ()
+        elif isinstance(expr, ex.UnaryAggregation):
+            assert isinstance(expr.arg, ex.DerefOp)
+            inputs = (expr.arg.id.sql,)
+        elif isinstance(expr, ex.BinaryAggregation):
+            assert isinstance(expr.left, ex.DerefOp)
+            assert isinstance(expr.right, ex.DerefOp)
+            inputs = (
+                expr.left.id.sql,
+                expr.right.id.sql,
+            )
+        else:
+            raise ValueError(f"Unexpected aggregation: {expr.op}")
+
+        return self.compile_agg_op(expr.op, inputs)
+
     def compile_agg_op(self, op: agg_ops.WindowOp, inputs: Sequence[str] = []):
         if isinstance(op, agg_ops.ProductOp):
             # TODO: Need schema to cast back to original type if posisble (eg float back to int)
@@ -320,9 +338,9 @@ def compile_sample(self, node: nodes.RandomSampleNode):
     @compile_node.register
     def compile_window(self, node: nodes.WindowOpNode):
         df = self.compile_node(node.child)
-        agg_expr = self.agg_compiler.compile_agg_op(
-            node.op, [node.column_name.id.sql]
-        ).alias(node.output_name.sql)
+        agg_expr = self.agg_compiler.compile_agg_expr(node.expression).alias(
+            node.output_name.sql
+        )
         # Three window types: completely unbound, grouped and row bounded

         window = node.window_spec
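The new compile_agg_expr is a thin adapter from an Aggregation expression to the (op, input column names) pair that compile_agg_op already accepts. A standalone mirror of that unpacking, with stand-in types in place of the bigframes expression classes:

from dataclasses import dataclass
from typing import Tuple, Union

@dataclass
class Ref:
    sql: str  # stand-in for ex.DerefOp.id.sql

@dataclass
class NullaryAgg:
    op: str

@dataclass
class UnaryAgg:
    op: str
    arg: Ref

@dataclass
class BinaryAgg:
    op: str
    left: Ref
    right: Ref

def agg_inputs(expr: Union[NullaryAgg, UnaryAgg, BinaryAgg]) -> Tuple[str, ...]:
    # Same shape as compile_agg_expr: nullary -> (), unary -> (col,), binary -> (left, right)
    if isinstance(expr, NullaryAgg):
        return ()
    if isinstance(expr, UnaryAgg):
        return (expr.arg.sql,)
    if isinstance(expr, BinaryAgg):
        return (expr.left.sql, expr.right.sql)
    raise ValueError(f"Unexpected aggregation: {expr}")

print(agg_inputs(BinaryAgg("corr", Ref("x"), Ref("y"))))  # ('x', 'y')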
18 changes: 9 additions & 9 deletions bigframes/core/nodes.py
@@ -33,7 +33,6 @@
 import bigframes.core.slices as slices
 import bigframes.core.window_spec as window
 import bigframes.dtypes
-import bigframes.operations.aggregations as agg_ops

 if typing.TYPE_CHECKING:
     import bigframes.core.ordering as orderings
@@ -1325,16 +1324,15 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):

 @dataclasses.dataclass(frozen=True, eq=False)
 class WindowOpNode(UnaryNode):
-    column_name: ex.DerefOp
-    op: agg_ops.UnaryWindowOp
+    expression: ex.Aggregation
     window_spec: window.WindowSpec
     output_name: bigframes.core.identifiers.ColumnId
     never_skip_nulls: bool = False
     skip_reproject_unsafe: bool = False

     def _validate(self):
         """Validate the local data in the node."""
-        assert self.column_name.id in self.child.ids
+        assert all(ref in self.child.ids for ref in self.expression.column_references)

     @property
     def non_local(self) -> bool:
@@ -1363,9 +1361,11 @@ def row_count(self) -> Optional[int]:

     @functools.cached_property
     def added_field(self) -> Field:
-        input_type = self.child.get_type(self.column_name.id)
-        new_item_dtype = self.op.output_type(input_type)
-        return Field(self.output_name, new_item_dtype)
+        input_types = self.child._dtype_lookup
+        return Field(
+            self.output_name,
+            bigframes.dtypes.dtype_for_etype(self.expression.output_type(input_types)),
+        )

     @property
     def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
@@ -1376,7 +1376,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
             return self.child.prune(used_cols)
         consumed_ids = (
             used_cols.difference([self.output_name])
-            .union([self.column_name.id])
+            .union(self.expression.column_references)
             .union(self.window_spec.all_referenced_columns)
         )
         return self.transform_children(lambda x: x.prune(consumed_ids))
@@ -1391,7 +1391,7 @@
     def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
         return dataclasses.replace(
             self,
-            column_name=self.column_name.remap_column_refs(
+            expression=self.expression.remap_column_refs(
                 mappings, allow_partial_bindings=True
             ),
             window_spec=self.window_spec.remap_column_refs(
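Two consequences in this file: added_field now derives the output dtype from the whole expression evaluated against the child's dtype lookup, and prune() must keep every column the expression references rather than a single column_name. A tiny set-algebra sketch of the prune() change (the column ids here are made up for illustration):

# Hypothetical column ids, just to show the set arithmetic in prune().
used_cols = {"out", "keep_me"}
output_name = "out"
expression_refs = {"x", "y"}   # e.g. a binary aggregation over x and y
window_refs = {"grp"}          # grouping/ordering columns of the window spec

consumed = (used_cols - {output_name}) | expression_refs | window_refs
print(sorted(consumed))  # ['grp', 'keep_me', 'x', 'y']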
2 changes: 1 addition & 1 deletion bigframes/operations/aggregations.py
@@ -381,7 +381,7 @@ def skips_nulls(self):

 # This should really by a NullaryWindowOp, but APIs don't support that yet.
 @dataclasses.dataclass(frozen=True)
-class RowNumberOp(UnaryWindowOp):
+class RowNumberOp(NullaryWindowOp):
     name: ClassVar[str] = "rownumber"

     @property
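With the node now accepting aggregations of any arity, RowNumberOp can finally be reparented from UnaryWindowOp to NullaryWindowOp, as the comment above it anticipated. A stand-in sketch of the hierarchy effect (simplified dataclasses, not the real agg_ops definitions):

import dataclasses
from typing import ClassVar

@dataclasses.dataclass(frozen=True)
class WindowOp:
    pass

@dataclasses.dataclass(frozen=True)
class NullaryWindowOp(WindowOp):  # zero input columns, e.g. row_number
    pass

@dataclasses.dataclass(frozen=True)
class UnaryWindowOp(WindowOp):    # exactly one input column, e.g. sum
    pass

@dataclasses.dataclass(frozen=True)
class RowNumberOp(NullaryWindowOp):
    name: ClassVar[str] = "rownumber"

# A nullary op can now travel through WindowOpNode as
# ex.NullaryAggregation(RowNumberOp()), with an empty column_references tuple.
print(RowNumberOp().name)  # rownumber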