From 9a185e2b5c623842b0f38c23db7695a3baa258bb Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Sat, 20 Jan 2024 00:03:24 +0000 Subject: [PATCH 1/3] refactor: add output type annotations to scalar ops --- bigframes/core/expression.py | 43 ++++- bigframes/dtypes.py | 63 ++++++- bigframes/functions/remote_function.py | 50 ++--- bigframes/operations/__init__.py | 245 +++++++++++++++++-------- bigframes/operations/type.py | 73 ++++++++ tests/unit/core/test_expression.py | 49 +++++ 6 files changed, 403 insertions(+), 120 deletions(-) create mode 100644 bigframes/operations/type.py create mode 100644 tests/unit/core/test_expression.py diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 540f9b6e5a..6472c222b2 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -23,11 +23,14 @@ import bigframes.dtypes import bigframes.operations +# Only constant null subexpressions should have no type +DtypeOrNoneType = Optional[bigframes.dtypes.Dtype] -def const( - value: typing.Hashable, dtype: Optional[bigframes.dtypes.Dtype] = None -) -> Expression: - return ScalarConstantExpression(value, dtype) + +def const(value: typing.Hashable, dtype: DtypeOrNoneType = None) -> Expression: + return ScalarConstantExpression( + value, dtype or bigframes.dtypes.infer_literal_type(value) + ) def free_var(id: str) -> Expression: @@ -45,9 +48,16 @@ def unbound_variables(self) -> typing.Tuple[str, ...]: def rename(self, name_mapping: dict[str, str]) -> Expression: return self - @abc.abstractproperty + @property + @abc.abstractmethod def is_const(self) -> bool: - return False + ... + + @abc.abstractmethod + def output_type( + self, input_types: dict[str, bigframes.dtypes.Dtype] + ) -> bigframes.dtypes.Dtype: + ... @dataclasses.dataclass(frozen=True) @@ -62,6 +72,11 @@ class ScalarConstantExpression(Expression): def is_const(self) -> bool: return True + def output_type( + self, input_types: dict[str, bigframes.dtypes.Dtype] + ) -> DtypeOrNoneType: + return self.dtype + @dataclasses.dataclass(frozen=True) class UnboundVariableExpression(Expression): @@ -83,6 +98,14 @@ def rename(self, name_mapping: dict[str, str]) -> Expression: def is_const(self) -> bool: return False + def output_type( + self, input_types: dict[str, bigframes.dtypes.Dtype] + ) -> DtypeOrNoneType: + if self.id in input_types: + return input_types[self.id] + else: + raise ValueError("Type of variable has not been fixed.") + @dataclasses.dataclass(frozen=True) class OpExpression(Expression): @@ -110,3 +133,11 @@ def rename(self, name_mapping: dict[str, str]) -> Expression: @property def is_const(self) -> bool: return all(child.is_const for child in self.inputs) + + def output_type( + self, input_types: dict[str, bigframes.dtypes.Dtype] + ) -> DtypeOrNoneType: + operand_types = tuple( + map(lambda x: x.output_type(input_types=input_types), self.inputs) + ) + return self.op.output_type(*operand_types) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index 608885dec4..b81d430bb6 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -23,7 +23,9 @@ import geopandas as gpd # type: ignore import google.cloud.bigquery as bigquery import ibis +from ibis.backends.bigquery.datatypes import BigQueryType import ibis.expr.datatypes as ibis_dtypes +from ibis.expr.datatypes.core import dtype as python_type_to_bigquery_type import ibis.expr.types as ibis_types import numpy as np import pandas as pd @@ -42,6 +44,12 @@ pd.ArrowDtype, gpd.array.GeometryDtype, ] +ExpressionType = Union[Dtype, None] + +INT_DTYPE = pd.Int64Dtype() +FLOAT_DTYPE = pd.Float64Dtype() +BOOL_DTYPE = pd.BooleanDtype() +STRING_DTYPE = pd.StringDtype(storage="pyarrow") # On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable UNORDERED_DTYPES = [gpd.array.GeometryDtype()] @@ -539,20 +547,20 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]: return lcd_type(pd.Int64Dtype(), dtype) if isinstance(scalar, decimal.Decimal): # TODO: Check context to see if can use NUMERIC instead of BIGNUMERIC - return lcd_type(pd.ArrowDtype(pa.decimal128(76, 38)), dtype) + return lcd_type(pd.ArrowDtype(pa.decimal256(76, 38)), dtype) return None -def lcd_type(dtype1: Dtype, dtype2: Dtype) -> typing.Optional[Dtype]: +def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype: if dtype1 == dtype2: return dtype1 # Implicit conversion currently only supported for numeric types hierarchy: list[Dtype] = [ pd.BooleanDtype(), pd.Int64Dtype(), - pd.Float64Dtype(), pd.ArrowDtype(pa.decimal128(38, 9)), pd.ArrowDtype(pa.decimal256(76, 38)), + pd.Float64Dtype(), ] if (dtype1 not in hierarchy) or (dtype2 not in hierarchy): return None @@ -560,6 +568,14 @@ def lcd_type(dtype1: Dtype, dtype2: Dtype) -> typing.Optional[Dtype]: return hierarchy[lcd_index] +def lcd_etype(etype1: ExpressionType, etype2: ExpressionType) -> ExpressionType: + if etype1 is None: + return etype2 + if etype2 is None: + return etype1 + return lcd_type_or_throw(etype1, etype2) + + def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype: result = lcd_type(dtype1, dtype2) if result is None: @@ -567,3 +583,44 @@ def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype: f"BigFrames cannot upcast {dtype1} and {dtype2} to common type. {constants.FEEDBACK_LINK}" ) return result + + +def infer_literal_type(literal) -> typing.Optional[Dtype]: + if pd.isna(literal): + return None # Null value without a definite type + # Temporary logic, use ibis inferred type + ibis_literal = literal_to_ibis_scalar(literal) + return ibis_dtype_to_bigframes_dtype(ibis_literal.type()) + + +# Input and output types supported by BigQuery DataFrames remote functions. +# TODO(shobs): Extend the support to all types supported by BQ remote functions +# https://cloud.google.com/bigquery/docs/remote-functions#limitations +SUPPORTED_IO_PYTHON_TYPES = {bool, float, int, str} +SUPPORTED_IO_BIGQUERY_TYPEKINDS = { + "BOOLEAN", + "BOOL", + "FLOAT", + "FLOAT64", + "INT64", + "INTEGER", + "STRING", +} + + +class UnsupportedTypeError(ValueError): + def __init__(self, type_, supported_types): + self.type = type_ + self.supported_types = supported_types + + +def ibis_type_from_python_type(t: type) -> ibis_dtypes.DataType: + if t not in SUPPORTED_IO_PYTHON_TYPES: + raise UnsupportedTypeError(t, SUPPORTED_IO_PYTHON_TYPES) + return python_type_to_bigquery_type(t) + + +def ibis_type_from_type_kind(tk: bigquery.StandardSqlTypeNames) -> ibis_dtypes.DataType: + if tk not in SUPPORTED_IO_BIGQUERY_TYPEKINDS: + raise UnsupportedTypeError(tk, SUPPORTED_IO_BIGQUERY_TYPEKINDS) + return BigQueryType.to_ibis(tk) diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index f54c26fa56..dfffbe65ac 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -46,12 +46,12 @@ from ibis.backends.bigquery.compiler import compiles from ibis.backends.bigquery.datatypes import BigQueryType from ibis.expr.datatypes.core import DataType as IbisDataType -from ibis.expr.datatypes.core import dtype as python_type_to_bigquery_type import ibis.expr.operations as ops import ibis.expr.rules as rlz from bigframes import clients import bigframes.constants as constants +import bigframes.dtypes logger = logging.getLogger(__name__) @@ -59,20 +59,6 @@ # https://docs.python.org/3/library/pickle.html#data-stream-format _pickle_protocol_version = 4 -# Input and output types supported by BigQuery DataFrames remote functions. -# TODO(shobs): Extend the support to all types supported by BQ remote functions -# https://cloud.google.com/bigquery/docs/remote-functions#limitations -SUPPORTED_IO_PYTHON_TYPES = {bool, float, int, str} -SUPPORTED_IO_BIGQUERY_TYPEKINDS = { - "BOOLEAN", - "BOOL", - "FLOAT", - "FLOAT64", - "INT64", - "INTEGER", - "STRING", -} - def get_remote_function_locations(bq_location): """Get BQ location and cloud functions region given a BQ client.""" @@ -558,24 +544,6 @@ def f(*args, **kwargs): return f -class UnsupportedTypeError(ValueError): - def __init__(self, type_, supported_types): - self.type = type_ - self.supported_types = supported_types - - -def ibis_type_from_python_type(t: type) -> IbisDataType: - if t not in SUPPORTED_IO_PYTHON_TYPES: - raise UnsupportedTypeError(t, SUPPORTED_IO_PYTHON_TYPES) - return python_type_to_bigquery_type(t) - - -def ibis_type_from_type_kind(tk: bigquery.StandardSqlTypeNames) -> IbisDataType: - if tk not in SUPPORTED_IO_BIGQUERY_TYPEKINDS: - raise UnsupportedTypeError(tk, SUPPORTED_IO_BIGQUERY_TYPEKINDS) - return BigQueryType.to_ibis(tk) - - def ibis_signature_from_python_signature( signature: inspect.Signature, input_types: Sequence[type], @@ -583,8 +551,10 @@ def ibis_signature_from_python_signature( ) -> IbisSignature: return IbisSignature( parameter_names=list(signature.parameters.keys()), - input_types=[ibis_type_from_python_type(t) for t in input_types], - output_type=ibis_type_from_python_type(output_type), + input_types=[ + bigframes.dtypes.ibis_type_from_python_type(t) for t in input_types + ], + output_type=bigframes.dtypes.ibis_type_from_python_type(output_type), ) @@ -599,10 +569,14 @@ def ibis_signature_from_routine(routine: bigquery.Routine) -> IbisSignature: return IbisSignature( parameter_names=[arg.name for arg in routine.arguments], input_types=[ - ibis_type_from_type_kind(arg.data_type.type_kind) if arg.data_type else None + bigframes.dtypes.ibis_type_from_type_kind(arg.data_type.type_kind) + if arg.data_type + else None for arg in routine.arguments ], - output_type=ibis_type_from_type_kind(routine.return_type.type_kind), + output_type=bigframes.dtypes.ibis_type_from_type_kind( + routine.return_type.type_kind + ), ) @@ -908,7 +882,7 @@ def read_gbq_function( raise ValueError( "Function return type must be specified. {constants.FEEDBACK_LINK}" ) - except UnsupportedTypeError as e: + except bigframes.dtypes.UnsupportedTypeError as e: raise ValueError( f"Type {e.type} not supported, supported types are {e.supported_types}. " f"{constants.FEEDBACK_LINK}" diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index 9737df94f9..e6ff31769a 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -20,6 +20,7 @@ import numpy as np import bigframes.dtypes as dtypes +import bigframes.operations.type as op_typing if typing.TYPE_CHECKING: # Avoids circular dependency @@ -36,6 +37,9 @@ def arguments(self) -> int: """The number of column argument the operation takes""" raise NotImplementedError("RowOp abstract base class has no implementation") + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + raise NotImplementedError("Abstract typing rule has no output type") + # These classes can be used to create simple ops that don't take local parameters # All is needed is a unique name, and to register an implementation in ibis_mappings.py @@ -49,6 +53,9 @@ def name(self) -> str: def arguments(self) -> int: return 1 + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + raise NotImplementedError("Abstract operation has no output type") + def as_expr( self, input_id: typing.Union[str, bigframes.core.expression.Expression] = "arg" ) -> bigframes.core.expression.Expression: @@ -69,6 +76,9 @@ def name(self) -> str: def arguments(self) -> int: return 2 + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + raise NotImplementedError("Abstract operation has no output type") + def as_expr( self, left_input: typing.Union[str, bigframes.core.expression.Expression] = "arg1", @@ -95,6 +105,9 @@ def name(self) -> str: def arguments(self) -> int: return 3 + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + raise NotImplementedError("Abstract operation has no output type") + def as_expr( self, input1: typing.Union[str, bigframes.core.expression.Expression] = "arg1", @@ -126,28 +139,43 @@ def _convert_expr_input( # Operation Factories -def create_unary_op(name: str) -> UnaryOp: +def create_unary_op( + name: str, type_rule: op_typing.OpTypeRule = op_typing.INPUT_TYPE +) -> UnaryOp: + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return type_rule.output_type(*input_types) + return dataclasses.make_dataclass( name, - [("name", typing.ClassVar[str], name)], # type: ignore + [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], output_type)], # type: ignore bases=(UnaryOp,), frozen=True, )() -def create_binary_op(name: str) -> BinaryOp: +def create_binary_op( + name: str, type_rule: op_typing.OpTypeRule = op_typing.Supertype() +) -> BinaryOp: + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return type_rule.output_type(*input_types) + return dataclasses.make_dataclass( name, - [("name", typing.ClassVar[str], name)], # type: ignore + [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], output_type)], # type: ignore bases=(BinaryOp,), frozen=True, )() -def create_ternary_op(name: str) -> TernaryOp: +def create_ternary_op( + name: str, type_rule: op_typing.OpTypeRule = op_typing.Supertype() +) -> TernaryOp: + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return type_rule.output_type(*input_types) + return dataclasses.make_dataclass( name, - [("name", typing.ClassVar[str], name)], # type: ignore + [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], output_type)], # type: ignore bases=(TernaryOp,), frozen=True, )() @@ -155,57 +183,57 @@ def create_ternary_op(name: str) -> TernaryOp: # Unary Ops ## Generic Ops -invert_op = create_unary_op(name="invert") -isnull_op = create_unary_op(name="isnull") -notnull_op = create_unary_op(name="notnull") -hash_op = create_unary_op(name="hash") +invert_op = create_unary_op(name="invert", type_rule=op_typing.INPUT_TYPE) +isnull_op = create_unary_op(name="isnull", type_rule=op_typing.PREDICATE) +notnull_op = create_unary_op(name="notnull", type_rule=op_typing.PREDICATE) +hash_op = create_unary_op(name="hash", type_rule=op_typing.INTEGER) ## String Ops -len_op = create_unary_op(name="len") -reverse_op = create_unary_op(name="reverse") -lower_op = create_unary_op(name="lower") -upper_op = create_unary_op(name="upper") -strip_op = create_unary_op(name="strip") -isalnum_op = create_unary_op(name="isalnum") -isalpha_op = create_unary_op(name="isalpha") -isdecimal_op = create_unary_op(name="isdecimal") -isdigit_op = create_unary_op(name="isdigit") -isnumeric_op = create_unary_op(name="isnumeric") -isspace_op = create_unary_op(name="isspace") -islower_op = create_unary_op(name="islower") -isupper_op = create_unary_op(name="isupper") -rstrip_op = create_unary_op(name="rstrip") -lstrip_op = create_unary_op(name="lstrip") -capitalize_op = create_unary_op(name="capitalize") +len_op = create_unary_op(name="len", type_rule=op_typing.INTEGER) +reverse_op = create_unary_op(name="reverse", type_rule=op_typing.STRING) +lower_op = create_unary_op(name="lower", type_rule=op_typing.STRING) +upper_op = create_unary_op(name="upper", type_rule=op_typing.STRING) +strip_op = create_unary_op(name="strip", type_rule=op_typing.STRING) +isalnum_op = create_unary_op(name="isalnum", type_rule=op_typing.PREDICATE) +isalpha_op = create_unary_op(name="isalpha", type_rule=op_typing.PREDICATE) +isdecimal_op = create_unary_op(name="isdecimal", type_rule=op_typing.PREDICATE) +isdigit_op = create_unary_op(name="isdigit", type_rule=op_typing.PREDICATE) +isnumeric_op = create_unary_op(name="isnumeric", type_rule=op_typing.PREDICATE) +isspace_op = create_unary_op(name="isspace", type_rule=op_typing.PREDICATE) +islower_op = create_unary_op(name="islower", type_rule=op_typing.PREDICATE) +isupper_op = create_unary_op(name="isupper", type_rule=op_typing.PREDICATE) +rstrip_op = create_unary_op(name="rstrip", type_rule=op_typing.STRING) +lstrip_op = create_unary_op(name="lstrip", type_rule=op_typing.STRING) +capitalize_op = create_unary_op(name="capitalize", type_rule=op_typing.STRING) ## DateTime Ops -day_op = create_unary_op(name="day") -dayofweek_op = create_unary_op(name="dayofweek") +day_op = create_unary_op(name="day", type_rule=op_typing.INTEGER) +dayofweek_op = create_unary_op(name="dayofweek", type_rule=op_typing.INTEGER) date_op = create_unary_op(name="date") -hour_op = create_unary_op(name="hour") -minute_op = create_unary_op(name="minute") -month_op = create_unary_op(name="month") -quarter_op = create_unary_op(name="quarter") -second_op = create_unary_op(name="second") -time_op = create_unary_op(name="time") -year_op = create_unary_op(name="year") +hour_op = create_unary_op(name="hour", type_rule=op_typing.INTEGER) +minute_op = create_unary_op(name="minute", type_rule=op_typing.INTEGER) +month_op = create_unary_op(name="month", type_rule=op_typing.INTEGER) +quarter_op = create_unary_op(name="quarter", type_rule=op_typing.INTEGER) +second_op = create_unary_op(name="second", type_rule=op_typing.INTEGER) +time_op = create_unary_op(name="time", type_rule=op_typing.INTEGER) +year_op = create_unary_op(name="year", type_rule=op_typing.INTEGER) ## Trigonometry Ops -sin_op = create_unary_op(name="sin") -cos_op = create_unary_op(name="cos") -tan_op = create_unary_op(name="tan") -arcsin_op = create_unary_op(name="arcsin") -arccos_op = create_unary_op(name="arccos") -arctan_op = create_unary_op(name="arctan") -sinh_op = create_unary_op(name="sinh") -cosh_op = create_unary_op(name="cosh") -tanh_op = create_unary_op(name="tanh") -arcsinh_op = create_unary_op(name="arcsinh") -arccosh_op = create_unary_op(name="arccosh") -arctanh_op = create_unary_op(name="arctanh") +sin_op = create_unary_op(name="sin", type_rule=op_typing.REAL_NUMERIC) +cos_op = create_unary_op(name="cos", type_rule=op_typing.REAL_NUMERIC) +tan_op = create_unary_op(name="tan", type_rule=op_typing.REAL_NUMERIC) +arcsin_op = create_unary_op(name="arcsin", type_rule=op_typing.REAL_NUMERIC) +arccos_op = create_unary_op(name="arccos", type_rule=op_typing.REAL_NUMERIC) +arctan_op = create_unary_op(name="arctan", type_rule=op_typing.REAL_NUMERIC) +sinh_op = create_unary_op(name="sinh", type_rule=op_typing.REAL_NUMERIC) +cosh_op = create_unary_op(name="cosh", type_rule=op_typing.REAL_NUMERIC) +tanh_op = create_unary_op(name="tanh", type_rule=op_typing.REAL_NUMERIC) +arcsinh_op = create_unary_op(name="arcsinh", type_rule=op_typing.REAL_NUMERIC) +arccosh_op = create_unary_op(name="arccosh", type_rule=op_typing.REAL_NUMERIC) +arctanh_op = create_unary_op(name="arctanh", type_rule=op_typing.REAL_NUMERIC) ## Numeric Ops -abs_op = create_unary_op(name="abs") -exp_op = create_unary_op(name="exp") -ln_op = create_unary_op(name="log") -log10_op = create_unary_op(name="log10") -sqrt_op = create_unary_op(name="sqrt") +abs_op = create_unary_op(name="abs", type_rule=op_typing.INPUT_TYPE) +exp_op = create_unary_op(name="exp", type_rule=op_typing.REAL_NUMERIC) +ln_op = create_unary_op(name="log", type_rule=op_typing.REAL_NUMERIC) +log10_op = create_unary_op(name="log10", type_rule=op_typing.REAL_NUMERIC) +sqrt_op = create_unary_op(name="sqrt", type_rule=op_typing.REAL_NUMERIC) # Parameterized unary ops @@ -214,18 +242,27 @@ class StrContainsOp(UnaryOp): name: typing.ClassVar[str] = "str_contains" pat: str + def output_type(self, *input_types): + return dtypes.BOOL_DTYPE + @dataclasses.dataclass(frozen=True) class StrContainsRegexOp(UnaryOp): name: typing.ClassVar[str] = "str_contains_regex" pat: str + def output_type(self, *input_types): + return dtypes.BOOL_DTYPE + @dataclasses.dataclass(frozen=True) class StrGetOp(UnaryOp): name: typing.ClassVar[str] = "str_get" i: int + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class StrPadOp(UnaryOp): @@ -234,6 +271,9 @@ class StrPadOp(UnaryOp): fillchar: str side: typing.Literal["both", "left", "right"] + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class ReplaceStrOp(UnaryOp): @@ -241,6 +281,9 @@ class ReplaceStrOp(UnaryOp): pat: str repl: str + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class RegexReplaceStrOp(UnaryOp): @@ -248,24 +291,36 @@ class RegexReplaceStrOp(UnaryOp): pat: str repl: str + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class StartsWithOp(UnaryOp): name: typing.ClassVar[str] = "str_startswith" pat: typing.Sequence[str] + def output_type(self, *input_types): + return dtypes.BOOL_DTYPE + @dataclasses.dataclass(frozen=True) class EndsWithOp(UnaryOp): name: typing.ClassVar[str] = "str_endswith" pat: typing.Sequence[str] + def output_type(self, *input_types): + return dtypes.BOOL_DTYPE + @dataclasses.dataclass(frozen=True) class ZfillOp(UnaryOp): name: typing.ClassVar[str] = "str_zfill" width: int + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class StrFindOp(UnaryOp): @@ -274,6 +329,9 @@ class StrFindOp(UnaryOp): start: typing.Optional[int] end: typing.Optional[int] + def output_type(self, *input_types): + return dtypes.BOOL_DTYPE + @dataclasses.dataclass(frozen=True) class StrExtractOp(UnaryOp): @@ -281,6 +339,9 @@ class StrExtractOp(UnaryOp): pat: str n: int = 1 + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class StrSliceOp(UnaryOp): @@ -288,12 +349,18 @@ class StrSliceOp(UnaryOp): start: typing.Optional[int] end: typing.Optional[int] + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + @dataclasses.dataclass(frozen=True) class StrRepeatOp(UnaryOp): name: typing.ClassVar[str] = "str_repeat" repeats: int + def output_type(self, *input_types): + return dtypes.STRING_DTYPE + # Other parameterized unary operations @dataclasses.dataclass(frozen=True) @@ -305,8 +372,14 @@ class StructFieldOp(UnaryOp): @dataclasses.dataclass(frozen=True) class AsTypeOp(UnaryOp): name: typing.ClassVar[str] = "astype" + # TODO: Convert strings to dtype earlier to_type: dtypes.DtypeString | dtypes.Dtype + def output_type(self, *input_types): + if isinstance(self.to_type, str): + return dtypes.BIGFRAMES_STRING_TO_BIGFRAMES[self.to_type] + return self.to_type + @dataclasses.dataclass(frozen=True) class IsInOp(UnaryOp): @@ -314,6 +387,9 @@ class IsInOp(UnaryOp): values: typing.Tuple match_nulls: bool = True + def output_type(self, *input_types): + return dtypes.BOOL_DTYPE + @dataclasses.dataclass(frozen=True) class RemoteFunctionOp(UnaryOp): @@ -321,12 +397,21 @@ class RemoteFunctionOp(UnaryOp): func: typing.Callable apply_on_null: bool + def output_type(self, *input_types): + python_type = self.func.__signature__.output_type + ibis_type = dtypes.ibis_type_from_python_type(python_type) + dtype = dtypes.ibis_dtype_to_bigframes_dtype(ibis_type) + return dtype + @dataclasses.dataclass(frozen=True) class MapOp(UnaryOp): name = "map_values" mappings: typing.Tuple[typing.Tuple[typing.Hashable, typing.Hashable], ...] + def output_type(self, *input_types): + return input_types[0] + # Binary Ops fillna_op = create_binary_op(name="fillna") @@ -334,34 +419,48 @@ class MapOp(UnaryOp): clipupper_op = create_binary_op(name="clip_upper") coalesce_op = create_binary_op(name="coalesce") ## Math Ops -add_op = create_binary_op(name="add") -sub_op = create_binary_op(name="sub") -mul_op = create_binary_op(name="mul") -div_op = create_binary_op(name="div") -floordiv_op = create_binary_op(name="floordiv") -pow_op = create_binary_op(name="pow") -mod_op = create_binary_op(name="mod") -round_op = create_binary_op(name="round") -unsafe_pow_op = create_binary_op(name="unsafe_pow_op") +add_op = create_binary_op(name="add", type_rule=op_typing.NUMERIC) +sub_op = create_binary_op(name="sub", type_rule=op_typing.NUMERIC) +mul_op = create_binary_op(name="mul", type_rule=op_typing.NUMERIC) +div_op = create_binary_op(name="div", type_rule=op_typing.REAL_NUMERIC) +floordiv_op = create_binary_op(name="floordiv", type_rule=op_typing.REAL_NUMERIC) +pow_op = create_binary_op(name="pow", type_rule=op_typing.REAL_NUMERIC) +mod_op = create_binary_op(name="mod", type_rule=op_typing.NUMERIC) +round_op = create_binary_op(name="round", type_rule=op_typing.REAL_NUMERIC) +unsafe_pow_op = create_binary_op(name="unsafe_pow_op", type_rule=op_typing.REAL_NUMERIC) # Logical Ops -and_op = create_binary_op(name="and") -or_op = create_binary_op(name="or") +and_op = create_binary_op(name="and", type_rule=op_typing.PREDICATE) +or_op = create_binary_op(name="or", type_rule=op_typing.PREDICATE) ## Comparison Ops -eq_op = create_binary_op(name="eq") -eq_null_match_op = create_binary_op(name="eq_nulls_match") -ne_op = create_binary_op(name="ne") -lt_op = create_binary_op(name="lt") -gt_op = create_binary_op(name="gt") -le_op = create_binary_op(name="le") -ge_op = create_binary_op(name="ge") +eq_op = create_binary_op(name="eq", type_rule=op_typing.PREDICATE) +eq_null_match_op = create_binary_op( + name="eq_nulls_match", type_rule=op_typing.PREDICATE +) +ne_op = create_binary_op(name="ne", type_rule=op_typing.PREDICATE) +lt_op = create_binary_op(name="lt", type_rule=op_typing.PREDICATE) +gt_op = create_binary_op(name="gt", type_rule=op_typing.PREDICATE) +le_op = create_binary_op(name="le", type_rule=op_typing.PREDICATE) +ge_op = create_binary_op(name="ge", type_rule=op_typing.PREDICATE) ## String Ops -strconcat_op = create_binary_op(name="strconcat") +strconcat_op = create_binary_op(name="strconcat", type_rule=op_typing.STRING) + # Ternary Ops -where_op = create_ternary_op(name="where") -clip_op = create_ternary_op(name="clip") +@dataclasses.dataclass(frozen=True) +class WhereOp(TernaryOp): + name: typing.ClassVar[str] = "where" + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + # Second input is boolean and doesn't affect output type + return dtypes.lcd_etype(input_types[0], input_types[2]) + + +where_op = WhereOp() + + +clip_op = create_ternary_op(name="clip", type_rule=op_typing.Supertype()) # Just parameterless unary ops for now diff --git a/bigframes/operations/type.py b/bigframes/operations/type.py new file mode 100644 index 0000000000..f63fabeba5 --- /dev/null +++ b/bigframes/operations/type.py @@ -0,0 +1,73 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import functools + +import pandas as pd + +import bigframes.dtypes +from bigframes.dtypes import ExpressionType + +# TODO: Apply input type constraints to help pre-empt invalid expression construction + + +@dataclasses.dataclass +class OpTypeRule: + def output_type(self, *input_types: ExpressionType) -> ExpressionType: + raise NotImplementedError("Abstract typing rule has no output type") + + +@dataclasses.dataclass +class InputType(OpTypeRule): + def output_type(self, *input_types: ExpressionType) -> ExpressionType: + assert len(input_types) == 1 + return input_types[0] + + +@dataclasses.dataclass +class RealNumeric(OpTypeRule): + def output_type(self, *input_types: ExpressionType) -> ExpressionType: + all_ints = all(pd.api.types.is_integer(input) for input in input_types) + if all_ints: + return bigframes.dtypes.FLOAT_DTYPE + else: + return functools.reduce( + lambda t1, t2: bigframes.dtypes.lcd_etype(t1, t2), input_types + ) + + +@dataclasses.dataclass +class Supertype(OpTypeRule): + def output_type(self, *input_types: ExpressionType) -> ExpressionType: + return functools.reduce( + lambda t1, t2: bigframes.dtypes.lcd_etype(t1, t2), input_types + ) + + +@dataclasses.dataclass +class Fixed(OpTypeRule): + out_type: ExpressionType + + def output_type(self, *input_types: ExpressionType) -> ExpressionType: + return self.out_type + + +# Common type rules +NUMERIC = Supertype() +REAL_NUMERIC = RealNumeric() +PREDICATE = Fixed(bigframes.dtypes.BOOL_DTYPE) +INTEGER = Fixed(bigframes.dtypes.INT_DTYPE) +STRING = Fixed(bigframes.dtypes.STRING_DTYPE) +INPUT_TYPE = InputType() diff --git a/tests/unit/core/test_expression.py b/tests/unit/core/test_expression.py new file mode 100644 index 0000000000..f46c47a582 --- /dev/null +++ b/tests/unit/core/test_expression.py @@ -0,0 +1,49 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bigframes.core.expression as ex +import bigframes.dtypes as dtypes +import bigframes.operations as ops + + +def test_expression_dtype_simple(): + expression = ops.add_op.as_expr("a", "b") + result = expression.output_type({"a": dtypes.INT_DTYPE, "b": dtypes.INT_DTYPE}) + assert result == dtypes.INT_DTYPE + + +def test_expression_dtype_nested(): + expression = ops.add_op.as_expr( + "a", ops.abs_op.as_expr(ops.sub_op.as_expr("b", ex.const(3.14))) + ) + + result = expression.output_type({"a": dtypes.INT_DTYPE, "b": dtypes.INT_DTYPE}) + + assert result == dtypes.FLOAT_DTYPE + + +def test_expression_dtype_where(): + expression = ops.where_op.as_expr(ex.const(3), ex.const(True), ex.const(None)) + + result = expression.output_type({}) + + assert result == dtypes.INT_DTYPE + + +def test_expression_dtype_astype(): + expression = ops.AsTypeOp("Int64").as_expr(ex.const(3.14159)) + + result = expression.output_type({}) + + assert result == dtypes.INT_DTYPE From d6bc616e6b9b6f92eb3edfdb59a33ccf77c4c884 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 23 Jan 2024 23:02:04 +0000 Subject: [PATCH 2/3] use same expression type annotation everywhere --- bigframes/core/expression.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 6472c222b2..d1be644439 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -18,19 +18,13 @@ import dataclasses import itertools import typing -from typing import Optional -import bigframes.dtypes +import bigframes.dtypes as dtypes import bigframes.operations -# Only constant null subexpressions should have no type -DtypeOrNoneType = Optional[bigframes.dtypes.Dtype] - -def const(value: typing.Hashable, dtype: DtypeOrNoneType = None) -> Expression: - return ScalarConstantExpression( - value, dtype or bigframes.dtypes.infer_literal_type(value) - ) +def const(value: typing.Hashable, dtype: dtypes.ExpressionType = None) -> Expression: + return ScalarConstantExpression(value, dtype or dtypes.infer_literal_type(value)) def free_var(id: str) -> Expression: @@ -55,8 +49,8 @@ def is_const(self) -> bool: @abc.abstractmethod def output_type( - self, input_types: dict[str, bigframes.dtypes.Dtype] - ) -> bigframes.dtypes.Dtype: + self, input_types: dict[str, dtypes.ExpressionType] + ) -> dtypes.ExpressionType: ... @@ -66,7 +60,7 @@ class ScalarConstantExpression(Expression): # TODO: Further constrain? value: typing.Hashable - dtype: Optional[bigframes.dtypes.Dtype] = None + dtype: dtypes.ExpressionType = None @property def is_const(self) -> bool: @@ -74,7 +68,7 @@ def is_const(self) -> bool: def output_type( self, input_types: dict[str, bigframes.dtypes.Dtype] - ) -> DtypeOrNoneType: + ) -> dtypes.ExpressionType: return self.dtype @@ -100,7 +94,7 @@ def is_const(self) -> bool: def output_type( self, input_types: dict[str, bigframes.dtypes.Dtype] - ) -> DtypeOrNoneType: + ) -> dtypes.ExpressionType: if self.id in input_types: return input_types[self.id] else: @@ -135,8 +129,8 @@ def is_const(self) -> bool: return all(child.is_const for child in self.inputs) def output_type( - self, input_types: dict[str, bigframes.dtypes.Dtype] - ) -> DtypeOrNoneType: + self, input_types: dict[str, dtypes.ExpressionType] + ) -> dtypes.ExpressionType: operand_types = tuple( map(lambda x: x.output_type(input_types=input_types), self.inputs) ) From 03d0ce35164d1a226bfa908cabca306cc43bcb93 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 25 Jan 2024 00:03:51 +0000 Subject: [PATCH 3/3] pr comments --- bigframes/dtypes.py | 4 +++- bigframes/operations/__init__.py | 15 +++------------ bigframes/operations/type.py | 7 +++++++ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index b81d430bb6..cb2210bec6 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -44,7 +44,9 @@ pd.ArrowDtype, gpd.array.GeometryDtype, ] -ExpressionType = Union[Dtype, None] +# Represents both column types (dtypes) and local-only types +# None represents the type of a None scalar. +ExpressionType = typing.Optional[Dtype] INT_DTYPE = pd.Int64Dtype() FLOAT_DTYPE = pd.Float64Dtype() diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index e6ff31769a..b40f42a3e8 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -142,12 +142,9 @@ def _convert_expr_input( def create_unary_op( name: str, type_rule: op_typing.OpTypeRule = op_typing.INPUT_TYPE ) -> UnaryOp: - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return type_rule.output_type(*input_types) - return dataclasses.make_dataclass( name, - [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], output_type)], # type: ignore + [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], type_rule.as_method)], # type: ignore bases=(UnaryOp,), frozen=True, )() @@ -156,12 +153,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT def create_binary_op( name: str, type_rule: op_typing.OpTypeRule = op_typing.Supertype() ) -> BinaryOp: - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return type_rule.output_type(*input_types) - return dataclasses.make_dataclass( name, - [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], output_type)], # type: ignore + [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], type_rule.as_method)], # type: ignore bases=(BinaryOp,), frozen=True, )() @@ -170,12 +164,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT def create_ternary_op( name: str, type_rule: op_typing.OpTypeRule = op_typing.Supertype() ) -> TernaryOp: - def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - return type_rule.output_type(*input_types) - return dataclasses.make_dataclass( name, - [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], output_type)], # type: ignore + [("name", typing.ClassVar[str], name), ("output_type", typing.ClassVar[typing.Callable], type_rule.as_method)], # type: ignore bases=(TernaryOp,), frozen=True, )() diff --git a/bigframes/operations/type.py b/bigframes/operations/type.py index f63fabeba5..3c16f0cbe9 100644 --- a/bigframes/operations/type.py +++ b/bigframes/operations/type.py @@ -28,6 +28,13 @@ class OpTypeRule: def output_type(self, *input_types: ExpressionType) -> ExpressionType: raise NotImplementedError("Abstract typing rule has no output type") + @property + def as_method(self): + def meth(_, *input_types: ExpressionType) -> ExpressionType: + return self.output_type(*input_types) + + return meth + @dataclasses.dataclass class InputType(OpTypeRule):