diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py
index 3893ad12d1..4759c99016 100644
--- a/bigframes/bigquery/_operations/ai.py
+++ b/bigframes/bigquery/_operations/ai.py
@@ -337,6 +337,100 @@ def generate_double(
     return series_list[0]._apply_nary_op(operator, series_list[1:])


+@log_adapter.method_logger(custom_base_name="bigquery_ai")
+def if_(
+    prompt: PROMPT_TYPE,
+    *,
+    connection_id: str | None = None,
+) -> series.Series:
+    """
+    Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function
+    is optimized so that not every row needs to be evaluated by the LLM.
+
+    **Examples:**
+        >>> import bigframes.pandas as bpd
+        >>> import bigframes.bigquery as bbq
+        >>> bpd.options.display.progress_bar = None
+        >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
+        >>> bbq.ai.if_((us_state, " has a city called Springfield"))
+        0     True
+        1     True
+        2    False
+        dtype: boolean
+
+        >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
+        0    Massachusetts
+        1         Illinois
+        dtype: string
+
+    Args:
+        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
+            A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
+            or pandas Series.
+        connection_id (str, optional):
+            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
+            If not provided, the connection from the current session will be used.
+
+    Returns:
+        bigframes.series.Series: A new series of bools.
+    """
+
+    prompt_context, series_list = _separate_context_and_series(prompt)
+    assert len(series_list) > 0
+
+    operator = ai_ops.AIIf(
+        prompt_context=tuple(prompt_context),
+        connection_id=_resolve_connection_id(series_list[0], connection_id),
+    )
+
+    return series_list[0]._apply_nary_op(operator, series_list[1:])
+
+
+@log_adapter.method_logger(custom_base_name="bigquery_ai")
+def score(
+    prompt: PROMPT_TYPE,
+    *,
+    connection_id: str | None = None,
+) -> series.Series:
+    """
+    Computes a score based on rubrics described in natural language and returns it as a double value.
+    There is no fixed range for the returned score. To get high-quality results, provide a scoring
+    rubric with examples in the prompt.
+
+    **Examples:**
+        >>> import bigframes.pandas as bpd
+        >>> import bigframes.bigquery as bbq
+        >>> bpd.options.display.progress_bar = None
+        >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
+        >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3"))  # doctest: +SKIP
+        0    2.0
+        1    1.0
+        2    3.0
+        dtype: Float64
+
+    Args:
+        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
+            A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
+            or pandas Series.
+        connection_id (str, optional):
+            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
+            If not provided, the connection from the current session will be used.
+
+    Returns:
+        bigframes.series.Series: A new series of double (float) values.
+    """
+
+    prompt_context, series_list = _separate_context_and_series(prompt)
+    assert len(series_list) > 0
+
+    operator = ai_ops.AIScore(
+        prompt_context=tuple(prompt_context),
+        connection_id=_resolve_connection_id(series_list[0], connection_id),
+    )
+
+    return series_list[0]._apply_nary_op(operator, series_list[1:])
+
+
 def _separate_context_and_series(
     prompt: PROMPT_TYPE,
 ) -> Tuple[List[str | None], List[series.Series]]:
diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
index a0750ec73d..7280e9a40a 100644
--- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
@@ -2030,6 +2030,24 @@ def ai_generate_double(
     ).to_expr()


+@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True)
+def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.BooleanValue:
+
+    return ai_ops.AIIf(
+        _construct_prompt(values, op.prompt_context),  # type: ignore
+        op.connection_id,  # type: ignore
+    ).to_expr()
+
+
+@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
+def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.NumericValue:
+
+    return ai_ops.AIScore(
+        _construct_prompt(values, op.prompt_context),  # type: ignore
+        op.connection_id,  # type: ignore
+    ).to_expr()
+
+
 def _construct_prompt(
     col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
 ) -> ibis_types.StructValue:
diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py
index 3f909ebc92..46a79d1440 100644
--- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py
+++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py
@@ -54,6 +54,20 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
     return sge.func("AI.GENERATE_DOUBLE", *args)


+@register_nary_op(ops.AIIf, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
+    args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
+
+    return sge.func("AI.IF", *args)
+
+
+@register_nary_op(ops.AIScore, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
+    args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
+
+    return sge.func("AI.SCORE", *args)
+
+
 def _construct_prompt(
     exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
 ) -> sge.Kwarg:
@@ -83,10 +97,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
     if endpoit is not None:
         args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))

-    request_type = typing.cast(str, op_args["request_type"]).upper()
-    args.append(
-        sge.Kwarg(this="request_type", expression=sge.Literal.string(request_type))
-    )
+    request_type = typing.cast(str, op_args.get("request_type", None))
+    if request_type is not None:
+        args.append(
+            sge.Kwarg(
+                this="request_type", expression=sge.Literal.string(request_type.upper())
+            )
+        )

     model_params = typing.cast(str, op_args.get("model_params", None))
     if model_params is not None:
diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py
index 031e42cf03..e7d0751fc9 100644
--- a/bigframes/operations/__init__.py
+++ b/bigframes/operations/__init__.py
@@ -19,6 +19,8 @@
     AIGenerateBool,
     AIGenerateDouble,
     AIGenerateInt,
+    AIIf,
+    AIScore,
 )
 from bigframes.operations.array_ops import (
     ArrayIndexOp,
@@ -421,6 +423,8 @@
     "AIGenerateBool",
     "AIGenerateDouble",
     "AIGenerateInt",
+    "AIIf",
+    "AIScore",
     # Numpy ops mapping
     "NUMPY_TO_BINOP",
     "NUMPY_TO_OP",
diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py
index 5d710bf6b5..05d37d2a90 100644
--- a/bigframes/operations/ai_ops.py
+++ b/bigframes/operations/ai_ops.py
@@ -110,3 +110,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
             )
         )
     )
+
+
+@dataclasses.dataclass(frozen=True)
+class AIIf(base_ops.NaryOp):
+    name: ClassVar[str] = "ai_if"
+
+    prompt_context: Tuple[str | None, ...]
+    connection_id: str
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        return dtypes.BOOL_DTYPE
+
+
+@dataclasses.dataclass(frozen=True)
+class AIScore(base_ops.NaryOp):
+    name: ClassVar[str] = "ai_score"
+
+    prompt_context: Tuple[str | None, ...]
+    connection_id: str
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        return dtypes.FLOAT_DTYPE
diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py
index 890cd4fb2b..91499d0efe 100644
--- a/tests/system/small/bigquery/test_ai.py
+++ b/tests/system/small/bigquery/test_ai.py
@@ -203,5 +203,49 @@ def test_ai_generate_double_multi_model(session):
     )


+def test_ai_if(session):
+    s1 = bpd.Series(["apple", "bear"], session=session)
+    s2 = bpd.Series(["fruit", "tree"], session=session)
+    prompt = (s1, " is a ", s2)
+
+    result = bbq.ai.if_(prompt)
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == dtypes.BOOL_DTYPE
+
+
+def test_ai_if_multi_model(session):
+    df = session.from_glob_path(
+        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
+    )
+
+    result = bbq.ai.if_((df["image"], " contains an animal"))
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == dtypes.BOOL_DTYPE
+
+
+def test_ai_score(session):
+    s = bpd.Series(["Tiger", "Rabbit"], session=session)
+    prompt = ("Rank the relative weights of ", s, " on the scale from 1 to 3")
+
+    result = bbq.ai.score(prompt)
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == dtypes.FLOAT_DTYPE
+
+
+def test_ai_score_multi_model(session):
+    df = session.from_glob_path(
+        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
+    )
+    prompt = ("Rank the liveliness of ", df["image"], " on the scale from 1 to 3")
+
+    result = bbq.ai.score(prompt)
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == dtypes.FLOAT_DTYPE
+
+
 def _contains_no_nulls(s: series.Series) -> bool:
     return len(s) == s.count()
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql
new file mode 100644
index 0000000000..d5b6a9330d
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql
@@ -0,0 +1,16 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.IF(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'bigframes-dev.us.bigframes-default-connection'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql
new file mode 100644
index 0000000000..e2be615921
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql
@@ -0,0 +1,16 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.SCORE(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'bigframes-dev.us.bigframes-default-connection'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
index f20b39bc74..8f048a5bbf 100644
--- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
+++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -199,3 +199,33 @@ def test_ai_generate_double_with_model_param(
     )

     snapshot.assert_match(sql, "out.sql")
+
+
+def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
+    col_name = "string_col"
+
+    op = ops.AIIf(
+        prompt_context=(None, " is the same as ", None),
+        connection_id=CONNECTION_ID,
+    )
+
+    sql = utils._apply_unary_ops(
+        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
+    )
+
+    snapshot.assert_match(sql, "out.sql")
+
+
+def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
+    col_name = "string_col"
+
+    op = ops.AIScore(
+        prompt_context=(None, " is the same as ", None),
+        connection_id=CONNECTION_ID,
+    )
+
+    sql = utils._apply_unary_ops(
+        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
+    )
+
+    snapshot.assert_match(sql, "out.sql")
diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
index 836d15118c..8603c89cc8 100644
--- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
+++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
@@ -1116,6 +1116,12 @@ def visit_AIGenerateInt(self, op, **kwargs):
     def visit_AIGenerateDouble(self, op, **kwargs):
         return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs))

+    def visit_AIIf(self, op, **kwargs):
+        return sge.func("AI.IF", *self._compile_ai_args(**kwargs))
+
+    def visit_AIScore(self, op, **kwargs):
+        return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs))
+
     def _compile_ai_args(self, **kwargs):
         args = []

diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
index 05c5e7e0af..5289ee7e60 100644
--- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
+++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
@@ -91,3 +91,31 @@ def dtype(self) -> dt.Struct:
             ("status", dt.string),
         )
     )
+
+
+@public
+class AIIf(Value):
+    """Evaluates the prompt to True or False."""
+
+    prompt: Value
+    connection_id: Value[dt.String]
+
+    shape = rlz.shape_like("prompt")
+
+    @attribute
+    def dtype(self) -> dt.DataType:
+        return dt.bool
+
+
+@public
+class AIScore(Value):
+    """Generates a double score based on the prompt."""
+
+    prompt: Value
+    connection_id: Value[dt.String]
+
+    shape = rlz.shape_like("prompt")
+
+    @attribute
+    def dtype(self) -> dt.DataType:
+        return dt.float64
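
A minimal usage sketch of the two new entry points, `bbq.ai.if_` and `bbq.ai.score`, following the docstrings in this change. It assumes an authenticated BigFrames session with a default BigQuery connection configured; scores from AI.SCORE are doubles with no fixed range.

    import bigframes.pandas as bpd
    import bigframes.bigquery as bbq

    bpd.options.display.progress_bar = None

    # AI.IF as a filter: keep only the rows the model evaluates to True.
    us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
    springfield_states = us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]

    # AI.SCORE with an explicit rubric embedded in the prompt.
    animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
    weights = bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3"))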