From 3a5a06d2b1340a7bae8bd81437421f9c10e4058a Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 16 Jan 2025 21:23:43 +0000 Subject: [PATCH 1/3] refactor: Make window op node support non-unary ops --- bigframes/core/__init__.py | 3 +- bigframes/core/compile/aggregate_compiler.py | 3 +- bigframes/core/compile/compiled.py | 48 ++++++++++++-------- bigframes/core/compile/compiler.py | 3 +- bigframes/core/compile/polars/compiler.py | 26 +++++++++-- bigframes/core/nodes.py | 18 ++++---- bigframes/operations/aggregations.py | 2 +- 7 files changed, 65 insertions(+), 38 deletions(-) diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index ee9917f619..0bae094777 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -405,8 +405,7 @@ def project_window_op( ArrayValue( nodes.WindowOpNode( child=self.node, - column_name=ex.deref(column_name), - op=op, + expression=ex.UnaryAggregation(op, ex.deref(column_name)), window_spec=window_spec, output_name=ids.ColumnId(output_name), never_skip_nulls=never_skip_nulls, diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index f97856efa5..7a018a662e 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -479,10 +479,9 @@ def _( return _apply_window_if_present(column.dense_rank(), window) + 1 -@compile_unary_agg.register +@compile_nullary_agg.register def _( op: agg_ops.RowNumberOp, - column: ibis_types.Column, window=None, ) -> ibis_types.IntegerValue: return _apply_window_if_present(ibis_api.row_number(), window) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 526826495e..8da9251d49 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -861,8 +861,7 @@ def promote_offsets(self, col_id: str) -> OrderedIR: ## Methods that only work with ordering def project_window_op( self, - column_name: ex.DerefOp, - op: agg_ops.UnaryWindowOp, + expression: ex.Aggregation, window_spec: WindowSpec, output_name: str, *, @@ -881,8 +880,11 @@ def project_window_op( # See: https://github.com/ibis-project/ibis/issues/9773 used_exprs = map( self._compile_expression, - itertools.chain( - (column_name,), map(ex.DerefOp, window_spec.all_referenced_columns) + map( + ex.DerefOp, + itertools.chain( + expression.column_references, window_spec.all_referenced_columns + ), ), ) can_directly_window = not any( @@ -890,44 +892,54 @@ def project_window_op( ) if not can_directly_window: return self._reproject_to_table().project_window_op( - column_name, - op, + expression, window_spec, output_name, never_skip_nulls=never_skip_nulls, ) - column = typing.cast(ibis_types.Column, self._compile_expression(column_name)) window = self._ibis_window_from_spec( - window_spec, require_total_order=op.uses_total_row_ordering + window_spec, require_total_order=expression.op.uses_total_row_ordering ) bindings = {col: self._get_ibis_column(col) for col in self.column_ids} window_op = agg_compiler.compile_analytic( - ex.UnaryAggregation(op, column_name), + expression, window, bindings=bindings, ) + inputs = tuple( + typing.cast(ibis_types.Column, self._compile_expression(ex.DerefOp(column))) + for column in expression.column_references + ) clauses = [] - if op.skips_nulls and not never_skip_nulls: - clauses.append((column.isnull(), ibis_types.null())) - if window_spec.min_periods: - if op.skips_nulls: + if expression.op.skips_nulls and not never_skip_nulls: + for column in inputs: + 
clauses.append((column.isnull(), ibis_types.null())) + if window_spec.min_periods and len(inputs) > 0: + if expression.op.skips_nulls: # Most operations do not count NULL values towards min_periods + per_col_does_count = (column.notnull() for column in inputs) + # all inputs must be non-null for observation to count + is_observation = functools.reduce( + lambda x, y: x & y, per_col_does_count + ).astype(int) observation_count = agg_compiler.compile_analytic( - ex.UnaryAggregation(agg_ops.count_op, column_name), + ex.UnaryAggregation(agg_ops.sum_op, ex.deref("_observation_count")), window, - bindings=bindings, + bindings={"_observation_count": is_observation}, ) else: # Operations like count treat even NULLs as valid observations for the sake of min_periods # notnull is just used to convert null values to non-null (FALSE) values to be counted - denulled_value = typing.cast(ibis_types.BooleanColumn, column.notnull()) + is_observation = inputs[0].notnull() observation_count = agg_compiler.compile_analytic( - ex.UnaryAggregation(agg_ops.count_op, ex.deref("_denulled")), + ex.UnaryAggregation( + agg_ops.count_op, ex.deref("_observation_count") + ), window, - bindings={**bindings, "_denulled": denulled_value}, + bindings={"_observation_count": is_observation}, ) clauses.append( ( diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py index 9e87b4b4e8..9548bb48f4 100644 --- a/bigframes/core/compile/compiler.py +++ b/bigframes/core/compile/compiler.py @@ -364,8 +364,7 @@ def compile_aggregate(self, node: nodes.AggregateNode, ordered: bool = True): @_compile_node.register def compile_window(self, node: nodes.WindowOpNode, ordered: bool = True): result = self.compile_ordered_ir(node.child).project_window_op( - node.column_name, - node.op, + node.expression, node.window_spec, node.output_name.sql, never_skip_nulls=node.never_skip_nulls, diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 7d8d54a7f0..6d5b11a5e8 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -16,7 +16,7 @@ import dataclasses import functools import itertools -from typing import cast, Sequence, TYPE_CHECKING +from typing import cast, Sequence, Tuple, TYPE_CHECKING import bigframes.core import bigframes.core.expression as ex @@ -125,6 +125,24 @@ def get_args( f"Aggregation {agg} not yet supported in polars engine." 
) + def compile_agg_expr(self, expr: ex.Aggregation): + if isinstance(expr, ex.NullaryAggregation): + inputs: Tuple = () + elif isinstance(expr, ex.UnaryAggregation): + assert isinstance(expr.arg, ex.DerefOp) + inputs = (expr.arg.id.sql,) + elif isinstance(expr, ex.BinaryAggregation): + assert isinstance(expr.left, ex.DerefOp) + assert isinstance(expr.right, ex.DerefOp) + inputs = ( + expr.left.id.sql, + expr.right.id.sql, + ) + else: + raise ValueError(f"Unexpected aggregation: {expr.op}") + + return self.compile_agg_op(expr.op, inputs) + def compile_agg_op(self, op: agg_ops.WindowOp, inputs: Sequence[str] = []): if isinstance(op, agg_ops.ProductOp): # TODO: Need schema to cast back to original type if posisble (eg float back to int) @@ -320,9 +338,9 @@ def compile_sample(self, node: nodes.RandomSampleNode): @compile_node.register def compile_window(self, node: nodes.WindowOpNode): df = self.compile_node(node.child) - agg_expr = self.agg_compiler.compile_agg_op( - node.op, [node.column_name.id.sql] - ).alias(node.output_name.sql) + agg_expr = self.agg_compiler.compile_agg_expr(node.expression).alias( + node.output_name.sql + ) # Three window types: completely unbound, grouped and row bounded window = node.window_spec diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index fe79da2bf6..88d55ac70b 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -33,7 +33,6 @@ import bigframes.core.slices as slices import bigframes.core.window_spec as window import bigframes.dtypes -import bigframes.operations.aggregations as agg_ops if typing.TYPE_CHECKING: import bigframes.core.ordering as orderings @@ -1325,8 +1324,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): @dataclasses.dataclass(frozen=True, eq=False) class WindowOpNode(UnaryNode): - column_name: ex.DerefOp - op: agg_ops.UnaryWindowOp + expression: ex.Aggregation window_spec: window.WindowSpec output_name: bigframes.core.identifiers.ColumnId never_skip_nulls: bool = False @@ -1334,7 +1332,7 @@ class WindowOpNode(UnaryNode): def _validate(self): """Validate the local data in the node.""" - assert self.column_name.id in self.child.ids + assert all(ref in self.child.ids for ref in self.expression.column_references) @property def non_local(self) -> bool: @@ -1363,9 +1361,11 @@ def row_count(self) -> Optional[int]: @functools.cached_property def added_field(self) -> Field: - input_type = self.child.get_type(self.column_name.id) - new_item_dtype = self.op.output_type(input_type) - return Field(self.output_name, new_item_dtype) + input_types = self.child._dtype_lookup + return Field( + self.output_name, + bigframes.dtypes.dtype_for_etype(self.expression.output_type(input_types)), + ) @property def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: @@ -1376,7 +1376,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: return self.child.prune(used_cols) consumed_ids = ( used_cols.difference([self.output_name]) - .union([self.column_name.id]) + .union(self.expression.column_references) .union(self.window_spec.all_referenced_columns) ) return self.transform_children(lambda x: x.prune(consumed_ids)) @@ -1391,7 +1391,7 @@ def remap_vars( def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return dataclasses.replace( self, - column_name=self.column_name.remap_column_refs( + expression=self.expression.remap_column_refs( mappings, allow_partial_bindings=True ), window_spec=self.window_spec.remap_column_refs( diff --git a/bigframes/operations/aggregations.py 
b/bigframes/operations/aggregations.py index 9de58fe5db..365b664ee0 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -381,7 +381,7 @@ def skips_nulls(self): # This should really by a NullaryWindowOp, but APIs don't support that yet. @dataclasses.dataclass(frozen=True) -class RowNumberOp(UnaryWindowOp): +class RowNumberOp(NullaryWindowOp): name: ClassVar[str] = "rownumber" @property From c93a89985c0ea2c874cf9e3ffed5c958ef60dcd1 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 17 Jan 2025 18:30:45 +0000 Subject: [PATCH 2/3] fix cast in window compile --- bigframes/core/compile/compiled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 8da9251d49..3b0181e976 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -924,7 +924,7 @@ def project_window_op( # all inputs must be non-null for observation to count is_observation = functools.reduce( lambda x, y: x & y, per_col_does_count - ).astype(int) + ).cast(int) observation_count = agg_compiler.compile_analytic( ex.UnaryAggregation(agg_ops.sum_op, ex.deref("_observation_count")), window, From 4200c44b0e4d1d35c18ae18efacbe29a738c7058 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 23 Jan 2025 00:19:01 +0000 Subject: [PATCH 3/3] capitalize comment --- bigframes/core/compile/compiled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 3b0181e976..ae5e2ff8c0 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -921,7 +921,7 @@ def project_window_op( if expression.op.skips_nulls: # Most operations do not count NULL values towards min_periods per_col_does_count = (column.notnull() for column in inputs) - # all inputs must be non-null for observation to count + # All inputs must be non-null for observation to count is_observation = functools.reduce( lambda x, y: x & y, per_col_does_count ).cast(int)
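
For reviewers, a minimal sketch of the node construction this series enables: WindowOpNode now carries a full ex.Aggregation rather than a (column_name, op) pair, so nullary ops such as RowNumberOp no longer need a placeholder input column, and multi-input window ops become expressible. The shapes below are inferred from the diffs above; the NullaryAggregation constructor is an assumption (only an isinstance check on it appears in this series), and window_spec is passed through untouched since its construction is not shown here.

# Illustrative sketch only; see the caveats above.
import bigframes.core.expression as ex
import bigframes.core.identifiers as ids
import bigframes.core.nodes as nodes
import bigframes.operations.aggregations as agg_ops


def with_row_number(child: nodes.BigFrameNode, window_spec, output_name: str):
    # Nullary case: no input column reference is required anymore.
    # NullaryAggregation(op) is assumed to take just the op.
    return nodes.WindowOpNode(
        child=child,
        expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
        window_spec=window_spec,
        output_name=ids.ColumnId(output_name),
    )


def with_windowed_sum(child: nodes.BigFrameNode, window_spec, col: str, output_name: str):
    # Unary case: the op and its input reference travel together in the
    # expression, matching how ArrayValue.project_window_op now builds it.
    return nodes.WindowOpNode(
        child=child,
        expression=ex.UnaryAggregation(agg_ops.sum_op, ex.deref(col)),
        window_spec=window_spec,
        output_name=ids.ColumnId(output_name),
    )

The compiled-IR side mirrors this shape: project_window_op derives its ibis input columns from expression.column_references, and the min_periods check now counts rows where every input is non-null (summing an all-inputs-non-null indicator) rather than counting a single hard-coded column.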