diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py index 785691edd6..a7f75e7264 100644 --- a/bigframes/core/block_transforms.py +++ b/bigframes/core/block_transforms.py @@ -86,9 +86,10 @@ def indicate_duplicates( # Discard this value if there are copies ANYWHERE window_spec = windows.unbound(grouping_keys=tuple(columns)) block, dummy = block.create_constant(1) + # summing a constant-1 column counts rows, and works even with partial ordering block, val_count_col_id = block.apply_window_op( dummy, - agg_ops.count_op, + agg_ops.sum_op, window_spec=window_spec, ) block, duplicate_indicator = block.project_expr( diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 482c38ae3d..f97856efa5 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -479,6 +479,15 @@ def _( return _apply_window_if_present(column.dense_rank(), window) + 1 +@compile_unary_agg.register +def _( + op: agg_ops.RowNumberOp, + column: ibis_types.Column, + window=None, +) -> ibis_types.IntegerValue: + return _apply_window_if_present(ibis_api.row_number(), window) + + @compile_unary_agg.register def _(op: agg_ops.FirstOp, column: ibis_types.Column, window=None) -> ibis_types.Value: return _apply_window_if_present(column.first(), window) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index d4c814145b..f879eb3feb 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -1330,7 +1330,7 @@ def _ibis_window_from_spec( if require_total_order or isinstance(window_spec.bounds, RowsWindowBounds): # Some operators need an unambiguous ordering, so the table's total ordering is appended order_by = tuple([*order_by, *self._ibis_order]) - elif isinstance(window_spec.bounds, RowsWindowBounds): + elif require_total_order or isinstance(window_spec.bounds, RowsWindowBounds): # If window spec has following or preceding bounds, 
we need to apply an unambiguous ordering. order_by = tuple(self._ibis_order) else: diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index 6b7f56d708..9de58fe5db 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -379,6 +379,19 @@ def skips_nulls(self): return True +# This should really be a NullaryWindowOp, but APIs don't support that yet. +@dataclasses.dataclass(frozen=True) +class RowNumberOp(UnaryWindowOp): + name: ClassVar[str] = "rownumber" + + @property + def skips_nulls(self): + return False + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return dtypes.INT_DTYPE + + @dataclasses.dataclass(frozen=True) class RankOp(UnaryWindowOp): name: ClassVar[str] = "rank" diff --git a/tests/system/small/test_unordered.py b/tests/system/small/test_unordered.py index fe3411e266..106997f3e9 100644 --- a/tests/system/small/test_unordered.py +++ b/tests/system/small/test_unordered.py @@ -152,6 +152,27 @@ def test_unordered_merge(unordered_session): assert_pandas_df_equal(bf_result.to_pandas(), pd_result, ignore_order=True) +def test_unordered_drop_duplicates_ambiguous(unordered_session): + pd_df = pd.DataFrame( + {"a": [1, 1, 1], "b": [4, 4, 6], "c": [1, 1, 3]}, dtype=pd.Int64Dtype() + ) + bf_df = bpd.DataFrame(pd_df, session=unordered_session) + + # merge first to discard original ordering + bf_result = ( + bf_df.merge(bf_df, left_on="a", right_on="c") + .sort_values("c_y") + .drop_duplicates() + ) + pd_result = ( + pd_df.merge(pd_df, left_on="a", right_on="c") + .sort_values("c_y") + .drop_duplicates() + ) + + assert_pandas_df_equal(bf_result.to_pandas(), pd_result, ignore_order=True) + + def test_unordered_mode_cache_preserves_order(unordered_session): pd_df = pd.DataFrame( {"a": [1, 2, 3, 4, 5, 6], "b": [4, 5, 9, 3, 1, 6]}, dtype=pd.Int64Dtype() ) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery.py 
b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery.py deleted file mode 100644 index c090a1ca8f..0000000000 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery.py +++ /dev/null @@ -1,770 +0,0 @@ -"""Module to convert from Ibis expression to SQL string.""" - -from __future__ import annotations - -import re - -from bigframes_vendored.ibis import util -from bigframes_vendored.ibis.backends.sql.compilers.base import ( - NULL, - SQLGlotCompiler, - STAR, -) -from bigframes_vendored.ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType -from bigframes_vendored.ibis.backends.sql.rewrites import ( - exclude_unsupported_window_frame_from_ops, - exclude_unsupported_window_frame_from_rank, - exclude_unsupported_window_frame_from_row_number, -) -import bigframes_vendored.ibis.common.exceptions as com -from bigframes_vendored.ibis.common.temporal import ( - DateUnit, - IntervalUnit, - TimestampUnit, - TimeUnit, -) -import bigframes_vendored.ibis.expr.datatypes as dt -import bigframes_vendored.ibis.expr.operations as ops -import sqlglot as sg -from sqlglot.dialects import BigQuery -import sqlglot.expressions as sge - -_NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') - - -class BigQueryCompiler(SQLGlotCompiler): - dialect = BigQuery - type_mapper = BigQueryType - udf_type_mapper = BigQueryUDFType - rewrites = ( - exclude_unsupported_window_frame_from_ops, - exclude_unsupported_window_frame_from_row_number, - exclude_unsupported_window_frame_from_rank, - *SQLGlotCompiler.rewrites, - ) - - UNSUPPORTED_OPS = ( - ops.DateDiff, - ops.ExtractAuthority, - ops.ExtractUserInfo, - ops.FindInSet, - ops.Median, - ops.Quantile, - ops.MultiQuantile, - ops.RegexSplit, - ops.RowID, - ops.TimestampBucket, - ops.TimestampDiff, - ) - - NAN = sge.Cast( - this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) - ) - POS_INF = sge.Cast( - this=sge.convert("Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) - ) - 
NEG_INF = sge.Cast( - this=sge.convert("-Infinity"), to=sge.DataType(this=sge.DataType.Type.DOUBLE) - ) - - SIMPLE_OPS = { - ops.Arbitrary: "any_value", - ops.StringAscii: "ascii", - ops.BitAnd: "bit_and", - ops.BitOr: "bit_or", - ops.BitXor: "bit_xor", - ops.DateFromYMD: "date", - ops.Divide: "ieee_divide", - ops.EndsWith: "ends_with", - ops.GeoArea: "st_area", - ops.GeoAsBinary: "st_asbinary", - ops.GeoAsText: "st_astext", - ops.GeoAzimuth: "st_azimuth", - ops.GeoBuffer: "st_buffer", - ops.GeoCentroid: "st_centroid", - ops.GeoContains: "st_contains", - ops.GeoCoveredBy: "st_coveredby", - ops.GeoCovers: "st_covers", - ops.GeoDWithin: "st_dwithin", - ops.GeoDifference: "st_difference", - ops.GeoDisjoint: "st_disjoint", - ops.GeoDistance: "st_distance", - ops.GeoEndPoint: "st_endpoint", - ops.GeoEquals: "st_equals", - ops.GeoGeometryType: "st_geometrytype", - ops.GeoIntersection: "st_intersection", - ops.GeoIntersects: "st_intersects", - ops.GeoLength: "st_length", - ops.GeoMaxDistance: "st_maxdistance", - ops.GeoNPoints: "st_numpoints", - ops.GeoPerimeter: "st_perimeter", - ops.GeoPoint: "st_geogpoint", - ops.GeoPointN: "st_pointn", - ops.GeoStartPoint: "st_startpoint", - ops.GeoTouches: "st_touches", - ops.GeoUnaryUnion: "st_union_agg", - ops.GeoUnion: "st_union", - ops.GeoWithin: "st_within", - ops.GeoX: "st_x", - ops.GeoY: "st_y", - ops.Hash: "farm_fingerprint", - ops.IsInf: "is_inf", - ops.IsNan: "is_nan", - ops.Log10: "log10", - ops.LPad: "lpad", - ops.RPad: "rpad", - ops.Levenshtein: "edit_distance", - ops.Modulus: "mod", - ops.RegexReplace: "regexp_replace", - ops.RegexSearch: "regexp_contains", - ops.Time: "time", - ops.TimeFromHMS: "time_from_parts", - ops.TimestampNow: "current_timestamp", - ops.ExtractHost: "net.host", - } - - @staticmethod - def _minimize_spec(start, end, spec): - if ( - start is None - and isinstance(getattr(end, "value", None), ops.Literal) - and end.value.value == 0 - and end.following - ): - return None - return spec - - def 
visit_BoundingBox(self, op, *, arg): - name = type(op).__name__[len("Geo") :].lower() - return sge.Dot( - this=self.f.st_boundingbox(arg), expression=sg.to_identifier(name) - ) - - visit_GeoXMax = visit_GeoXMin = visit_GeoYMax = visit_GeoYMin = visit_BoundingBox - - def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed): - if ( - not isinstance(op.preserve_collapsed, ops.Literal) - or op.preserve_collapsed.value - ): - raise com.UnsupportedOperationError( - "BigQuery simplify does not support preserving collapsed geometries, " - "pass preserve_collapsed=False" - ) - return self.f.st_simplify(arg, tolerance) - - def visit_ApproxMedian(self, op, *, arg, where): - return self.agg.approx_quantiles(arg, 2, where=where)[self.f.offset(1)] - - def visit_Pi(self, op): - return self.f.acos(-1) - - def visit_E(self, op): - return self.f.exp(1) - - def visit_TimeDelta(self, op, *, left, right, part): - return self.f.time_diff(left, right, part, dialect=self.dialect) - - def visit_DateDelta(self, op, *, left, right, part): - return self.f.date_diff(left, right, part, dialect=self.dialect) - - def visit_TimestampDelta(self, op, *, left, right, part): - left_tz = op.left.dtype.timezone - right_tz = op.right.dtype.timezone - - if left_tz is None and right_tz is None: - return self.f.datetime_diff(left, right, part) - elif left_tz is not None and right_tz is not None: - return self.f.timestamp_diff(left, right, part) - - raise com.UnsupportedOperationError( - "timestamp difference with mixed timezone/timezoneless values is not implemented" - ) - - def visit_GroupConcat(self, op, *, arg, sep, where): - if where is not None: - arg = self.if_(where, arg, NULL) - return self.f.string_agg(arg, sep) - - def visit_FloorDivide(self, op, *, left, right): - return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) - - def visit_Log2(self, op, *, arg): - return self.f.log(arg, 2, dialect=self.dialect) - - def visit_Log(self, op, *, arg, base): - if base is 
None: - return self.f.ln(arg) - return self.f.log(arg, base, dialect=self.dialect) - - def visit_ArrayRepeat(self, op, *, arg, times): - start = step = 1 - array_length = self.f.array_length(arg) - stop = self.f.greatest(times, 0) * array_length - i = sg.to_identifier("i") - idx = self.f.coalesce( - self.f.nullif(self.f.mod(i, array_length), 0), array_length - ) - series = self.f.generate_array(start, stop, step) - return self.f.array( - sg.select(arg[self.f.safe_ordinal(idx)]).from_(self._unnest(series, as_=i)) - ) - - def visit_NthValue(self, op, *, arg, nth): - if not isinstance(op.nth, ops.Literal): - raise com.UnsupportedOperationError( - f"BigQuery `nth` must be a literal; got {type(op.nth)}" - ) - return self.f.nth_value(arg, nth) - - def visit_StrRight(self, op, *, arg, nchars): - return self.f.substr(arg, -self.f.least(self.f.length(arg), nchars)) - - def visit_StringJoin(self, op, *, arg, sep): - return self.f.array_to_string(self.f.array(*arg), sep) - - def visit_DayOfWeekIndex(self, op, *, arg): - return self.f.mod(self.f.extract(self.v.dayofweek, arg) + 5, 7) - - def visit_DayOfWeekName(self, op, *, arg): - return self.f.initcap(sge.Cast(this=arg, to="STRING FORMAT 'DAY'")) - - def visit_StringToTimestamp(self, op, *, arg, format_str): - if (timezone := op.dtype.timezone) is not None: - return self.f.parse_timestamp(format_str, arg, timezone) - return self.f.parse_datetime(format_str, arg) - - def visit_ArrayCollect(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - return self.f.array_agg(sge.IgnoreNulls(this=arg)) - - def _neg_idx_to_pos(self, arg, idx): - return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) - - def visit_ArraySlice(self, op, *, arg, start, stop): - index = sg.to_identifier("bq_arr_slice") - cond = [index >= self._neg_idx_to_pos(arg, start)] - - if stop is not None: - cond.append(index < self._neg_idx_to_pos(arg, stop)) - - el = sg.to_identifier("el") - return self.f.array( - 
sg.select(el).from_(self._unnest(arg, as_=el, offset=index)).where(*cond) - ) - - def visit_ArrayIndex(self, op, *, arg, index): - return arg[self.f.safe_offset(index)] - - def visit_ArrayContains(self, op, *, arg, other): - name = sg.to_identifier(util.gen_name("bq_arr_contains")) - return sge.Exists( - this=sg.select(sge.convert(1)) - .from_(self._unnest(arg, as_=name)) - .where(name.eq(other)) - ) - - def visit_StringContains(self, op, *, haystack, needle): - return self.f.strpos(haystack, needle) > 0 - - def visti_StringFind(self, op, *, arg, substr, start, end): - if start is not None: - raise NotImplementedError( - "`start` not implemented for BigQuery string find" - ) - if end is not None: - raise NotImplementedError("`end` not implemented for BigQuery string find") - return self.f.strpos(arg, substr) - - def visit_TimestampFromYMDHMS( - self, op, *, year, month, day, hours, minutes, seconds - ): - return self.f.anon.DATETIME(year, month, day, hours, minutes, seconds) - - def visit_NonNullLiteral(self, op, *, value, dtype): - if dtype.is_inet() or dtype.is_macaddr(): - return sge.convert(str(value)) - elif dtype.is_timestamp(): - funcname = "DATETIME" if dtype.timezone is None else "TIMESTAMP" - return self.f.anon[funcname](value.isoformat()) - elif dtype.is_date(): - return self.f.date_from_parts(value.year, value.month, value.day) - elif dtype.is_time(): - time = self.f.time_from_parts(value.hour, value.minute, value.second) - if micros := value.microsecond: - # bigquery doesn't support `time(12, 34, 56.789101)`, AKA a - # float seconds specifier, so add any non-zero micros to the - # time value - return sge.TimeAdd( - this=time, expression=sge.convert(micros), unit=self.v.MICROSECOND - ) - return time - elif dtype.is_binary(): - return sge.Cast( - this=sge.convert(value.hex()), - to=sge.DataType(this=sge.DataType.Type.BINARY), - format=sge.convert("HEX"), - ) - elif dtype.is_interval(): - if dtype.unit == IntervalUnit.NANOSECOND: - raise 
com.UnsupportedOperationError( - "BigQuery does not support nanosecond intervals" - ) - elif dtype.is_uuid(): - return sge.convert(str(value)) - return None - - def visit_IntervalFromInteger(self, op, *, arg, unit): - if unit == IntervalUnit.NANOSECOND: - raise com.UnsupportedOperationError( - "BigQuery does not support nanosecond intervals" - ) - return sge.Interval(this=arg, unit=self.v[unit.singular]) - - def visit_Strftime(self, op, *, arg, format_str): - arg_dtype = op.arg.dtype - if arg_dtype.is_timestamp(): - if (timezone := arg_dtype.timezone) is None: - return self.f.format_datetime(format_str, arg) - else: - return self.f.format_timestamp(format_str, arg, timezone) - elif arg_dtype.is_date(): - return self.f.format_date(format_str, arg) - else: - assert arg_dtype.is_time(), arg_dtype - return self.f.format_time(format_str, arg) - - def visit_IntervalMultiply(self, op, *, left, right): - unit = self.v[op.left.dtype.resolution.upper()] - return sge.Interval(this=self.f.extract(unit, left) * right, unit=unit) - - def visit_TimestampFromUNIX(self, op, *, arg, unit): - unit = op.unit - if unit == TimestampUnit.SECOND: - return self.f.timestamp_seconds(arg) - elif unit == TimestampUnit.MILLISECOND: - return self.f.timestamp_millis(arg) - elif unit == TimestampUnit.MICROSECOND: - return self.f.timestamp_micros(arg) - elif unit == TimestampUnit.NANOSECOND: - return self.f.timestamp_micros( - self.cast(self.f.round(arg / 1_000), dt.int64) - ) - else: - raise com.UnsupportedOperationError(f"Unit not supported: {unit}") - - def visit_Cast(self, op, *, arg, to): - from_ = op.arg.dtype - if from_.is_timestamp() and to.is_integer(): - return self.f.unix_micros(arg) - elif from_.is_integer() and to.is_timestamp(): - return self.f.timestamp_seconds(arg) - elif from_.is_interval() and to.is_integer(): - if from_.unit in { - IntervalUnit.WEEK, - IntervalUnit.QUARTER, - IntervalUnit.NANOSECOND, - }: - raise com.UnsupportedOperationError( - f"BigQuery does not allow 
extracting date part `{from_.unit}` from intervals" - ) - return self.f.extract(self.v[to.resolution.upper()], arg) - elif from_.is_integer() and to.is_interval(): - return sge.Interval(this=arg, unit=self.v[to.unit.singular]) - elif from_.is_floating() and to.is_integer(): - return self.cast(self.f.trunc(arg), dt.int64) - return super().visit_Cast(op, arg=arg, to=to) - - def visit_JSONGetItem(self, op, *, arg, index): - return arg[index] - - def visit_UnwrapJSONString(self, op, *, arg): - return self.f.anon["safe.string"](arg) - - def visit_UnwrapJSONInt64(self, op, *, arg): - return self.f.anon["safe.int64"](arg) - - def visit_UnwrapJSONFloat64(self, op, *, arg): - return self.f.anon["safe.float64"](arg) - - def visit_UnwrapJSONBoolean(self, op, *, arg): - return self.f.anon["safe.bool"](arg) - - def visit_ExtractEpochSeconds(self, op, *, arg): - return self.f.unix_seconds(arg) - - def visit_ExtractWeekOfYear(self, op, *, arg): - return self.f.extract(self.v.isoweek, arg) - - def visit_ExtractIsoYear(self, op, *, arg): - return self.f.extract(self.v.isoyear, arg) - - def visit_ExtractMillisecond(self, op, *, arg): - return self.f.extract(self.v.millisecond, arg) - - def visit_ExtractMicrosecond(self, op, *, arg): - return self.f.extract(self.v.microsecond, arg) - - def visit_TimestampTruncate(self, op, *, arg, unit): - if unit == IntervalUnit.NANOSECOND: - raise com.UnsupportedOperationError( - f"BigQuery does not support truncating {op.arg.dtype} values to unit {unit!r}" - ) - elif unit == IntervalUnit.WEEK: - unit = "WEEK(MONDAY)" - else: - unit = unit.name - return self.f.timestamp_trunc(arg, self.v[unit], dialect=self.dialect) - - def visit_DateTruncate(self, op, *, arg, unit): - if unit == DateUnit.WEEK: - unit = "WEEK(MONDAY)" - else: - unit = unit.name - return self.f.date_trunc(arg, self.v[unit], dialect=self.dialect) - - def visit_TimeTruncate(self, op, *, arg, unit): - if unit == TimeUnit.NANOSECOND: - raise com.UnsupportedOperationError( - f"BigQuery 
does not support truncating {op.arg.dtype} values to unit {unit!r}" - ) - else: - unit = unit.name - return self.f.time_trunc(arg, self.v[unit], dialect=self.dialect) - - def _nullifzero(self, step, zero, step_dtype): - if step_dtype.is_interval(): - return self.if_(step.eq(zero), NULL, step) - return self.f.nullif(step, zero) - - def _zero(self, dtype): - if dtype.is_interval(): - return self.f.make_interval() - return sge.convert(0) - - def _sign(self, value, dtype): - if dtype.is_interval(): - zero = self._zero(dtype) - return sge.Case( - ifs=[ - self.if_(value < zero, -1), - self.if_(value.eq(zero), 0), - self.if_(value > zero, 1), - ], - default=NULL, - ) - return self.f.sign(value) - - def _make_range(self, func, start, stop, step, step_dtype): - step_sign = self._sign(step, step_dtype) - delta_sign = self._sign(stop - start, step_dtype) - zero = self._zero(step_dtype) - nullifzero = self._nullifzero(step, zero, step_dtype) - condition = sg.and_(sg.not_(nullifzero.is_(NULL)), step_sign.eq(delta_sign)) - gen_array = func(start, stop, step) - name = sg.to_identifier(util.gen_name("bq_arr_range")) - inner = ( - sg.select(name) - .from_(self._unnest(gen_array, as_=name)) - .where(name.neq(stop)) - ) - return self.if_(condition, self.f.array(inner), self.f.array()) - - def visit_IntegerRange(self, op, *, start, stop, step): - return self._make_range(self.f.generate_array, start, stop, step, op.step.dtype) - - def visit_TimestampRange(self, op, *, start, stop, step): - if op.start.dtype.timezone is None or op.stop.dtype.timezone is None: - raise com.IbisTypeError( - "Timestamps without timezone values are not supported when generating timestamp ranges" - ) - return self._make_range( - self.f.generate_timestamp_array, start, stop, step, op.step.dtype - ) - - def visit_First(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - array = self.f.array_agg( - sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)), - ) - 
return array[self.f.safe_offset(0)] - - def visit_Last(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg))) - return array[self.f.safe_offset(0)] - - def visit_ArrayFilter(self, op, *, arg, body, param): - return self.f.array( - sg.select(param).from_(self._unnest(arg, as_=param)).where(body) - ) - - def visit_ArrayMap(self, op, *, arg, body, param): - return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param))) - - def visit_ArrayZip(self, op, *, arg): - lengths = [self.f.array_length(arr) - 1 for arr in arg] - idx = sg.to_identifier(util.gen_name("bq_arr_idx")) - indices = self._unnest( - self.f.generate_array(0, self.f.greatest(*lengths)), as_=idx - ) - struct_fields = [ - arr[self.f.safe_offset(idx)].as_(name) - for name, arr in zip(op.dtype.value_type.names, arg) - ] - return self.f.array( - sge.Select(kind="STRUCT", expressions=struct_fields).from_(indices) - ) - - def visit_ArrayPosition(self, op, *, arg, other): - name = sg.to_identifier(util.gen_name("bq_arr")) - idx = sg.to_identifier(util.gen_name("bq_arr_idx")) - unnest = self._unnest(arg, as_=name, offset=idx) - return self.f.coalesce( - sg.select(idx + 1).from_(unnest).where(name.eq(other)).limit(1).subquery(), - 0, - ) - - def _unnest(self, expression, *, as_, offset=None): - alias = sge.TableAlias(columns=[sg.to_identifier(as_)]) - return sge.Unnest(expressions=[expression], alias=alias, offset=offset) - - def visit_ArrayRemove(self, op, *, arg, other): - name = sg.to_identifier(util.gen_name("bq_arr")) - unnest = self._unnest(arg, as_=name) - return self.f.array(sg.select(name).from_(unnest).where(name.neq(other))) - - def visit_ArrayDistinct(self, op, *, arg): - name = util.gen_name("bq_arr") - return self.f.array( - sg.select(name).distinct().from_(self._unnest(arg, as_=name)) - ) - - def visit_ArraySort(self, op, *, arg): - name = util.gen_name("bq_arr") - return self.f.array( 
- sg.select(name).from_(self._unnest(arg, as_=name)).order_by(name) - ) - - def visit_ArrayUnion(self, op, *, left, right): - lname = util.gen_name("bq_arr_left") - rname = util.gen_name("bq_arr_right") - lhs = sg.select(lname).from_(self._unnest(left, as_=lname)) - rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) - return self.f.array(sg.union(lhs, rhs, distinct=True)) - - def visit_ArrayIntersect(self, op, *, left, right): - lname = util.gen_name("bq_arr_left") - rname = util.gen_name("bq_arr_right") - lhs = sg.select(lname).from_(self._unnest(left, as_=lname)) - rhs = sg.select(rname).from_(self._unnest(right, as_=rname)) - return self.f.array(sg.intersect(lhs, rhs, distinct=True)) - - def visit_RegexExtract(self, op, *, arg, pattern, index): - matches = self.f.regexp_contains(arg, pattern) - nonzero_index_replace = self.f.regexp_replace( - arg, - self.f.concat(".*?", pattern, ".*"), - self.f.concat("\\", self.cast(index, dt.string)), - ) - zero_index_replace = self.f.regexp_replace( - arg, self.f.concat(".*?", self.f.concat("(", pattern, ")"), ".*"), "\\1" - ) - extract = self.if_(index.eq(0), zero_index_replace, nonzero_index_replace) - return self.if_(matches, extract, NULL) - - def visit_TimestampAddSub(self, op, *, left, right): - if not isinstance(right, sge.Interval): - raise com.OperationNotDefinedError( - "BigQuery does not support non-literals on the right side of timestamp add/subtract" - ) - if (unit := op.right.dtype.unit) == IntervalUnit.NANOSECOND: - raise com.UnsupportedOperationError( - f"BigQuery does not allow binary operation {type(op).__name__} with " - f"INTERVAL offset {unit}" - ) - - opname = type(op).__name__[len("Timestamp") :] - funcname = f"TIMESTAMP_{opname.upper()}" - return self.f.anon[funcname](left, right) - - visit_TimestampAdd = visit_TimestampSub = visit_TimestampAddSub - - def visit_DateAddSub(self, op, *, left, right): - if not isinstance(right, sge.Interval): - raise com.OperationNotDefinedError( - "BigQuery 
does not support non-literals on the right side of date add/subtract" - ) - if not (unit := op.right.dtype.unit).is_date(): - raise com.UnsupportedOperationError( - f"BigQuery does not allow binary operation {type(op).__name__} with " - f"INTERVAL offset {unit}" - ) - opname = type(op).__name__[len("Date") :] - funcname = f"DATE_{opname.upper()}" - return self.f.anon[funcname](left, right) - - visit_DateAdd = visit_DateSub = visit_DateAddSub - - def visit_Covariance(self, op, *, left, right, how, where): - if where is not None: - left = self.if_(where, left, NULL) - right = self.if_(where, right, NULL) - - if op.left.dtype.is_boolean(): - left = self.cast(left, dt.int64) - - if op.right.dtype.is_boolean(): - right = self.cast(right, dt.int64) - - how = op.how[:4].upper() - assert how in ("POP", "SAMP"), 'how not in ("POP", "SAMP")' - return self.agg[f"COVAR_{how}"](left, right, where=where) - - def visit_Correlation(self, op, *, left, right, how, where): - if how == "sample": - raise ValueError(f"Correlation with how={how!r} is not supported.") - - if where is not None: - left = self.if_(where, left, NULL) - right = self.if_(where, right, NULL) - - if op.left.dtype.is_boolean(): - left = self.cast(left, dt.int64) - - if op.right.dtype.is_boolean(): - right = self.cast(right, dt.int64) - - return self.agg.corr(left, right, where=where) - - def visit_TypeOf(self, op, *, arg): - return self._pudf("typeof", arg) - - def visit_Xor(self, op, *, left, right): - return sg.or_(sg.and_(left, sg.not_(right)), sg.and_(sg.not_(left), right)) - - def visit_HashBytes(self, op, *, arg, how): - if how not in ("md5", "sha1", "sha256", "sha512"): - raise NotImplementedError(how) - return self.f[how](arg) - - @staticmethod - def _gen_valid_name(name: str) -> str: - return "_".join(map(str.strip, _NAME_REGEX.findall(name))) or "tmp" - - def visit_CountStar(self, op, *, arg, where): - if where is not None: - return self.f.countif(where) - return self.f.count(STAR) - - def 
visit_CountDistinctStar(self, op, *, where, arg): - # Bigquery does not support count(distinct a,b,c) or count(distinct (a, b, c)) - # as expressions must be "groupable": - # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#group_by_grouping_item - # - # Instead, convert the entire expression to a string - # SELECT COUNT(DISTINCT concat(to_json_string(a), to_json_string(b))) - # This works with an array of datatypes which generates a unique string - # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_encodings - row = sge.Concat( - expressions=[ - self.f.to_json_string(sg.column(x, quoted=self.quoted)) - for x in op.arg.schema.keys() - ] - ) - if where is not None: - row = self.if_(where, row, NULL) - return self.f.count(sge.Distinct(expressions=[row])) - - def visit_Degrees(self, op, *, arg): - return self._pudf("degrees", arg) - - def visit_Radians(self, op, *, arg): - return self._pudf("radians", arg) - - def visit_CountDistinct(self, op, *, arg, where): - if where is not None: - arg = self.if_(where, arg, NULL) - return self.f.count(sge.Distinct(expressions=[arg])) - - def visit_RandomUUID(self, op, **kwargs): - return self.f.generate_uuid() - - def visit_ExtractFile(self, op, *, arg): - return self._pudf("cw_url_extract_file", arg) - - def visit_ExtractFragment(self, op, *, arg): - return self._pudf("cw_url_extract_fragment", arg) - - def visit_ExtractPath(self, op, *, arg): - return self._pudf("cw_url_extract_path", arg) - - def visit_ExtractProtocol(self, op, *, arg): - return self._pudf("cw_url_extract_protocol", arg) - - def visit_ExtractQuery(self, op, *, arg, key): - if key is not None: - return self._pudf("cw_url_extract_parameter", arg, key) - else: - return self._pudf("cw_url_extract_query", arg) - - def _pudf(self, name, *args): - name = sg.table(name, db="persistent_udfs", catalog="bigquery-public-data").sql( - self.dialect - ) - return self.f[name](*args) - - def visit_DropColumns(self, op, 
*, parent, columns_to_drop): - quoted = self.quoted - excludes = [sg.column(column, quoted=quoted) for column in columns_to_drop] - star = sge.Star(**{"except": excludes}) - table = sg.to_identifier(parent.alias_or_name, quoted=quoted) - column = sge.Column(this=star, table=table) - return sg.select(column).from_(parent) - - def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool - ): - quoted = self.quoted - - column_alias = sg.to_identifier( - util.gen_name("table_unnest_column"), quoted=quoted - ) - - selcols = [] - - table = sg.to_identifier(parent.alias_or_name, quoted=quoted) - - opname = op.column.name - overlaps_with_parent = opname in op.parent.schema - computed_column = column_alias.as_(opname, quoted=quoted) - - # replace the existing column if the unnested column hasn't been - # renamed - # - # e.g., table.unnest("x") - if overlaps_with_parent: - selcols.append( - sge.Column(this=sge.Star(replace=[computed_column]), table=table) - ) - else: - selcols.append(sge.Column(this=STAR, table=table)) - selcols.append(computed_column) - - if offset is not None: - offset = sg.to_identifier(offset, quoted=quoted) - selcols.append(offset) - - unnest = sge.Unnest( - expressions=[column], - alias=sge.TableAlias(columns=[column_alias]), - offset=offset, - ) - return ( - sg.select(*selcols) - .from_(parent) - .join(unnest, join_type="CROSS" if not keep_empty else "LEFT") - ) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 9de3e09540..3793a09229 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -252,17 +252,6 @@ def to_sqlglot( sources.append(result) return sources - @staticmethod - def _minimize_spec(start, end, spec): - if ( - start is None - and isinstance(getattr(end, 
"value", None), ops.Literal) - and end.value.value == 0 - and end.following - ): - return None - return spec - def visit_BoundingBox(self, op, *, arg): name = type(op).__name__[len("Geo") :].lower() return sge.Dot( @@ -1105,27 +1094,26 @@ def visit_Quantile(self, op, *, arg, quantile, where): def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by): # Patch for https://github.com/ibis-project/ibis/issues/9872 - if start is None and end is None: - spec = None - else: - if start is None: - start = {} - if end is None: - end = {} - start_value = start.get("value", "UNBOUNDED") - start_side = start.get("side", "PRECEDING") - end_value = end.get("value", "UNBOUNDED") - end_side = end.get("side", "FOLLOWING") + if start is None: + start = {} + if end is None: + end = {} - if getattr(start_value, "this", None) == "0": - start_value = "CURRENT ROW" - start_side = None + start_value = start.get("value", "UNBOUNDED") + start_side = start.get("side", "PRECEDING") + end_value = end.get("value", "UNBOUNDED") + end_side = end.get("side", "FOLLOWING") - if getattr(end_value, "this", None) == "0": - end_value = "CURRENT ROW" - end_side = None + if getattr(start_value, "this", None) == "0": + start_value = "CURRENT ROW" + start_side = None + if getattr(end_value, "this", None) == "0": + end_value = "CURRENT ROW" + end_side = None + + if how != "none": spec = sge.WindowSpec( kind=how.upper(), start=start_value, @@ -1134,7 +1122,12 @@ def visit_WindowFunction(self, op, *, how, func, start, end, group_by, order_by) end_side=end_side, over="OVER", ) - spec = self._minimize_spec(op.start, op.end, spec) + else: + spec = None + + # If unordered, unbound range window is implicit + if (not order_by) and (not start) and (not end): + spec = None order = sge.Order(expressions=order_by) if order_by else None diff --git a/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py index 65119aa40a..b2ef6a15d3 
100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py @@ -400,25 +400,26 @@ def rewrite_empty_order_by_window(_, **kwargs): return _.copy(order_by=(ops.NULL,)) -@replace(p.WindowFunction(p.RowNumber | p.NTile)) +@replace(p.WindowFunction(p.RowNumber | p.NTile | p.MinRank | p.DenseRank)) def exclude_unsupported_window_frame_from_row_number(_, **kwargs): - return ops.Subtract(_.copy(start=None, end=0), 1) + # These functions do not support window bounds, only an ordering. + # Also, it's kind of messy to insert subtract here; should probably be in visitor + return ops.Subtract( + _.copy(how="none", start=None, end=None, order_by=_.order_by or (ops.NULL,)), 1 + ) -@replace(p.WindowFunction(p.MinRank | p.DenseRank, start=None)) +@replace(p.WindowFunction(p.PercentRank | p.CumeDist, start=None)) def exclude_unsupported_window_frame_from_rank(_, **kwargs): - return ops.Subtract( - _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)), 1 - ) + # These functions do not support window bounds, only an ordering. 
+ # Also, it's kind of messy to insert subtract here; should probably be in visitor + return _.copy(how="none", start=None, end=None, order_by=_.order_by or (ops.NULL,)) -@replace( - p.WindowFunction( - p.Lag | p.Lead | p.PercentRank | p.CumeDist | p.Any | p.All, start=None - ) -) +@replace(p.WindowFunction(p.Lag | p.Lead, start=None)) def exclude_unsupported_window_frame_from_ops(_, **kwargs): - return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,)) + # lag/lead don't support bounds, but do support ordering + return _.copy(how="none", start=None, end=None, order_by=_.order_by or (ops.NULL,)) # Rewrite rules for lowering a high-level operation into one composed of more diff --git a/third_party/bigframes_vendored/ibis/expr/operations/window.py b/third_party/bigframes_vendored/ibis/expr/operations/window.py index 0c9ae91fc7..0fcecb4109 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/window.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/window.py @@ -69,7 +69,8 @@ class WindowFunction(Value): """Window function operation.""" func: Analytic | Reduction - how: LiteralType["rows", "range"] = "rows" # noqa: F821 + # none is a hacky way to express that window bounds are not supported (e.g. row_number()) + how: LiteralType["rows", "range", "none"] = "rows" # noqa: F821 start: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None end: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None group_by: VarTuple[Column] = () @@ -100,7 +101,7 @@ def __init__(self, how, start, end, **kwargs): raise com.IbisTypeError( "Window frame start and end boundaries must have the same datatype" ) - else: + elif how != "none": raise com.IbisTypeError( f"Window frame type must be either 'rows' or 'range', got {how}" )