From 0983a2dead51236a7d2cad4555f1d3f50d0b047e Mon Sep 17 00:00:00 2001
From: Shenyang Cai
Date: Tue, 23 Sep 2025 05:26:40 +0000
Subject: [PATCH] feat: add ai.generate_int to bigframes.bigquery package

---
 bigframes/bigquery/_operations/ai.py | 75 +++++++++++++++++++
 .../ibis_compiler/scalar_op_registry.py | 38 +++++++---
 .../compile/sqlglot/expressions/ai_ops.py | 51 +++++++++----
 bigframes/operations/__init__.py | 3 +-
 bigframes/operations/ai_ops.py | 23 +++++-
 tests/system/small/bigquery/test_ai.py | 73 +++++++++++++-----
 .../test_ai_ops/test_ai_generate_int/out.sql | 18 +++++
 .../out.sql | 18 +++++
 .../sqlglot/expressions/test_ai_ops.py | 52 ++++++++++++-
 .../sql/compilers/bigquery/__init__.py | 9 ++-
 .../ibis/expr/operations/ai_ops.py | 19 +++++
 11 files changed, 332 insertions(+), 47 deletions(-)
 create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql

diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py
index 3bafce6166..f0b4f51611 100644
--- a/bigframes/bigquery/_operations/ai.py
+++ b/bigframes/bigquery/_operations/ai.py
@@ -113,6 +113,81 @@ def generate_bool(
     return series_list[0]._apply_nary_op(operator, series_list[1:])
 
 
+@log_adapter.method_logger(custom_base_name="bigquery_ai")
+def generate_int(
+    prompt: PROMPT_TYPE,
+    *,
+    connection_id: str | None = None,
+    endpoint: str | None = None,
+    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
+    model_params: Mapping[Any, Any] | None = None,
+) -> series.Series:
+    """
+    Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
+
+    **Examples:**
+
+        >>> import bigframes.pandas as bpd
+        >>> import bigframes.bigquery as bbq
+        >>> bpd.options.display.progress_bar = None
+        >>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])
+        >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?"))
+        0    {'result': 2, 'full_response': '{"candidates":...
+        1    {'result': 4, 'full_response': '{"candidates":...
+        2    {'result': 8, 'full_response': '{"candidates":...
+        dtype: struct<result: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
+
+        >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?")).struct.field("result")
+        0    2
+        1    4
+        2    8
+        Name: result, dtype: Int64
+
+    Args:
+        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
+            A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
+            or pandas Series.
+        connection_id (str, optional):
+            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
+            If not provided, the connection from the current session will be used.
+        endpoint (str, optional):
+            Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
+            generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
+            uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
+            version of Gemini to use.
+        request_type (Literal["dedicated", "shared", "unspecified"]):
+            Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
+            * "dedicated": the function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not
+                purchased or is not active if Provisioned Throughput quota isn't available.
+            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
+            * "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
+                If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
+                If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
+        model_params (Mapping[Any, Any]):
+            Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
+
+    Returns:
+        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
+            * "result": an integer (INT64) value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
+            * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
+                The generated text is in the text element.
+            * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
+    """
+
+    prompt_context, series_list = _separate_context_and_series(prompt)
+    assert len(series_list) > 0
+
+    operator = ai_ops.AIGenerateInt(
+        prompt_context=tuple(prompt_context),
+        connection_id=_resolve_connection_id(series_list[0], connection_id),
+        endpoint=endpoint,
+        request_type=request_type,
+        model_params=json.dumps(model_params) if model_params else None,
+    )
+
+    return series_list[0]._apply_nary_op(operator, series_list[1:])
+
+
 def _separate_context_and_series(
     prompt: PROMPT_TYPE,
 ) -> Tuple[List[str | None], List[series.Series]]:
diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
index 8ffc556f76..a750a625ad 100644
--- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
+++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py
@@ -1975,23 +1975,43 @@ def ai_generate_bool(
     *values: ibis_types.Value, op: ops.AIGenerateBool
 ) -> ibis_types.StructValue:
 
+    return ai_ops.AIGenerateBool(
+        _construct_prompt(values, op.prompt_context),  # type: ignore
+        op.connection_id,  # type: ignore
+        op.endpoint,  # type: ignore
+        op.request_type.upper(),  # type: ignore
+        op.model_params,  # type: ignore
+    ).to_expr()
+
+
+@scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True)
+def ai_generate_int(
+    *values: ibis_types.Value, op: ops.AIGenerateInt
+) -> ibis_types.StructValue:
+
+    return ai_ops.AIGenerateInt(
+        _construct_prompt(values, op.prompt_context),  # type: ignore
+        op.connection_id,  # type: ignore
+        op.endpoint,  # type: ignore
+        op.request_type.upper(),  # type: ignore
+        op.model_params,  # type: ignore
+    ).to_expr()
+
+
+def _construct_prompt(
+    col_refs: tuple[ibis_types.Value, ...], prompt_context: tuple[str | None, ...]
+) -> ibis_types.StructValue:
     prompt: dict[str, ibis_types.Value | str] = {}
     column_ref_idx = 0
 
-    for idx, elem in enumerate(op.prompt_context):
+    for idx, elem in enumerate(prompt_context):
         if elem is None:
-            prompt[f"_field_{idx + 1}"] = values[column_ref_idx]
+            prompt[f"_field_{idx + 1}"] = col_refs[column_ref_idx]
             column_ref_idx += 1
         else:
             prompt[f"_field_{idx + 1}"] = elem
 
-    return ai_ops.AIGenerateBool(
-        ibis.struct(prompt),  # type: ignore
-        op.connection_id,  # type: ignore
-        op.endpoint,  # type: ignore
-        op.request_type.upper(),  # type: ignore
-        op.model_params,  # type: ignore
-    ).to_expr()
+    return ibis.struct(prompt)
 
 
 @scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py
index 8395461575..50d56611b1 100644
--- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py
+++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py
@@ -14,6 +14,9 @@
 
 from __future__ import annotations
 
+from dataclasses import asdict
+import typing
+
 import sqlglot.expressions as sge
 
 from bigframes import operations as ops
@@ -25,41 +28,61 @@
 
 @register_nary_op(ops.AIGenerateBool, pass_op=True)
 def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
+    args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
+
+    return sge.func("AI.GENERATE_BOOL", *args)
+
+
+@register_nary_op(ops.AIGenerateInt, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression:
+    args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
+
+    return sge.func("AI.GENERATE_INT", *args)
+
+def _construct_prompt(
+    exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
+) -> sge.Kwarg:
     prompt: list[str | sge.Expression] = []
     column_ref_idx = 0
 
-    for elem in op.prompt_context:
+    for elem in prompt_context:
         if elem is None:
             prompt.append(exprs[column_ref_idx].expr)
         else:
             prompt.append(sge.Literal.string(elem))
 
-    args = [sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))]
+    return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))
+
 
+def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
+    args = []
+
+    op_args = asdict(op)
+
+    connection_id = typing.cast(str, op_args["connection_id"])
     args.append(
-        sge.Kwarg(this="connection_id", expression=sge.Literal.string(op.connection_id))
+        sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
     )
 
-    if op.endpoint is not None:
-        args.append(
-            sge.Kwarg(this="endpoint", expression=sge.Literal.string(op.endpoint))
-        )
+    endpoint = typing.cast(str, op_args.get("endpoint", None))
+    if endpoint is not None:
+        args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))
 
+    request_type = typing.cast(str, op_args["request_type"]).upper()
     args.append(
-        sge.Kwarg(
-            this="request_type", expression=sge.Literal.string(op.request_type.upper())
-        )
+        sge.Kwarg(this="request_type", expression=sge.Literal.string(request_type))
     )
 
-    if op.model_params is not None:
+    model_params = typing.cast(str, op_args.get("model_params", None))
+    if model_params is not None:
         args.append(
             sge.Kwarg(
                 this="model_params",
-                # sge.JSON requires a newer SQLGlot version than 23.6.3.
+                # sge.JSON requires the SQLGlot version to be at least 25.18.0
                 # PARSE_JSON won't work as the function requires a JSON literal.
-                expression=sge.JSON(this=sge.Literal.string(op.model_params)),
+                expression=sge.JSON(this=sge.Literal.string(model_params)),
             )
         )
 
-    return sge.func("AI.GENERATE_BOOL", *args)
+    return args
 
diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py
index 6239b88e9e..17e1f7534f 100644
--- a/bigframes/operations/__init__.py
+++ b/bigframes/operations/__init__.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from bigframes.operations.ai_ops import AIGenerateBool
+from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateInt
 from bigframes.operations.array_ops import (
     ArrayIndexOp,
     ArrayReduceOp,
@@ -413,6 +413,7 @@
     "GeoStDistanceOp",
     # AI ops
     "AIGenerateBool",
+    "AIGenerateInt",
     # Numpy ops mapping
     "NUMPY_TO_BINOP",
     "NUMPY_TO_OP",
diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py
index 680c1585fb..7a8202abd2 100644
--- a/bigframes/operations/ai_ops.py
+++ b/bigframes/operations/ai_ops.py
@@ -28,7 +28,6 @@ class AIGenerateBool(base_ops.NaryOp):
 
     name: ClassVar[str] = "ai_generate_bool"
 
-    # None are the placeholders for column references.
     prompt_context: Tuple[str | None, ...]
     connection_id: str
     endpoint: str | None
@@ -45,3 +44,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
                 )
             )
         )
+
+
+@dataclasses.dataclass(frozen=True)
+class AIGenerateInt(base_ops.NaryOp):
+    name: ClassVar[str] = "ai_generate_int"
+
+    prompt_context: Tuple[str | None, ...]
+    connection_id: str
+    endpoint: str | None
+    request_type: Literal["dedicated", "shared", "unspecified"]
+    model_params: str | None
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        return pd.ArrowDtype(
+            pa.struct(
+                (
+                    pa.field("result", pa.int64()),
+                    pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                    pa.field("status", pa.string()),
+                )
+            )
+        )
diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py
index be67a0d580..9f6feb0bbc 100644
--- a/tests/system/small/bigquery/test_ai.py
+++ b/tests/system/small/bigquery/test_ai.py
@@ -12,19 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
-
+from packaging import version
 import pandas as pd
 import pyarrow as pa
 import pytest
+import sqlglot
 
 from bigframes import dtypes, series
 import bigframes.bigquery as bbq
 import bigframes.pandas as bpd
 
 
-def test_ai_generate_bool(session):
-    s1 = bpd.Series(["apple", "bear"], session=session)
+def test_ai_function_pandas_input(session):
+    s1 = pd.Series(["apple", "bear"])
     s2 = bpd.Series(["fruit", "tree"], session=session)
     prompt = (s1, " is a ", s2)
 
@@ -42,12 +42,20 @@
     )
 
 
-def test_ai_generate_bool_with_pandas(session):
-    s1 = pd.Series(["apple", "bear"])
+def test_ai_function_compile_model_params(session):
+    if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
+        pytest.skip(
+            "Skip test because SQLGlot cannot compile model params to JSON at this version."
+        )
+
+    s1 = bpd.Series(["apple", "bear"], session=session)
     s2 = bpd.Series(["fruit", "tree"], session=session)
     prompt = (s1, " is a ", s2)
+    model_params = {"generation_config": {"thinking_config": {"thinking_budget": 0}}}
 
-    result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
+    result = bbq.ai.generate_bool(
+        prompt, endpoint="gemini-2.5-flash", model_params=model_params
+    )
 
     assert _contains_no_nulls(result)
     assert result.dtype == pd.ArrowDtype(
@@ -61,20 +69,12 @@
     )
 
 
-def test_ai_generate_bool_with_model_params(session):
-    if sys.version_info < (3, 12):
-        pytest.skip(
-            "Skip test because SQLGLot cannot compile model params to JSON at this env."
-        )
-
+def test_ai_generate_bool(session):
     s1 = bpd.Series(["apple", "bear"], session=session)
     s2 = bpd.Series(["fruit", "tree"], session=session)
     prompt = (s1, " is a ", s2)
 
-    model_params = {"generation_config": {"thinking_config": {"thinking_budget": 0}}}
-    result = bbq.ai.generate_bool(
-        prompt, endpoint="gemini-2.5-flash", model_params=model_params
-    )
+    result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
 
     assert _contains_no_nulls(result)
     assert result.dtype == pd.ArrowDtype(
@@ -107,5 +107,44 @@ def test_ai_generate_bool_multi_model(session):
     )
 
 
+def test_ai_generate_int(session):
+    s = bpd.Series(["Cat"], session=session)
+    prompt = ("How many legs does a ", s, " have?")
+
+    result = bbq.ai.generate_int(prompt, endpoint="gemini-2.5-flash")
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == pd.ArrowDtype(
+        pa.struct(
+            (
+                pa.field("result", pa.int64()),
+                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                pa.field("status", pa.string()),
+            )
+        )
+    )
+
+
+def test_ai_generate_int_multi_model(session):
+    df = session.from_glob_path(
+        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
+    )
+
+    result = bbq.ai.generate_int(
+        ("How many animals are there in the picture ", df["image"])
+    )
+
+    assert _contains_no_nulls(result)
+    assert result.dtype == pd.ArrowDtype(
+        pa.struct(
+            (
+                pa.field("result", pa.int64()),
+                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
+                pa.field("status", pa.string()),
+            )
+        )
+    )
+
+
 def _contains_no_nulls(s: series.Series) -> bool:
     return len(s) == s.count()
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql
new file mode 100644
index 0000000000..e48b64bead
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.GENERATE_INT(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'test_connection_id',
+      endpoint => 'gemini-2.5-flash',
+      request_type => 'SHARED'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql
new file mode 100644
index 0000000000..6f406dea18
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `string_col` AS `bfcol_0`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    AI.GENERATE_INT(
+      prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
+      connection_id => 'test_connection_id',
+      request_type => 'SHARED',
+      model_params => JSON '{}'
+    ) AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `result`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
index 15b9ae516b..33a257f9a9 100644
--- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
+++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import json
-import sys
 
+from packaging import version
 import pytest
+import sqlglot
 
 from bigframes import dataframe
 from bigframes import operations as ops
@@ -45,9 +46,9 @@ def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot):
 def test_ai_generate_bool_with_model_param(
     scalar_types_df: dataframe.DataFrame, snapshot
 ):
-    if sys.version_info < (3, 10):
+    if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
         pytest.skip(
-            "Skip test because SQLGLot cannot compile model params to JSON at this env."
+            "Skip test because SQLGlot cannot compile model params to JSON at this version."
         )
 
     col_name = "string_col"
@@ -65,3 +66,48 @@ def test_ai_generate_bool_with_model_param(
     )
 
     snapshot.assert_match(sql, "out.sql")
+
+
+def test_ai_generate_int(scalar_types_df: dataframe.DataFrame, snapshot):
+    col_name = "string_col"
+
+    op = ops.AIGenerateInt(
+        # The prompt does not make semantic sense but we only care about syntax correctness.
+        prompt_context=(None, " is the same as ", None),
+        connection_id="test_connection_id",
+        endpoint="gemini-2.5-flash",
+        request_type="shared",
+        model_params=None,
+    )
+
+    sql = utils._apply_unary_ops(
+        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
+    )
+
+    snapshot.assert_match(sql, "out.sql")
+
+
+def test_ai_generate_int_with_model_param(
+    scalar_types_df: dataframe.DataFrame, snapshot
+):
+    if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
+        pytest.skip(
+            "Skip test because SQLGlot cannot compile model params to JSON at this version."
+        )
+
+    col_name = "string_col"
+
+    op = ops.AIGenerateInt(
+        # The prompt does not make semantic sense but we only care about syntax correctness.
+        prompt_context=(None, " is the same as ", None),
+        connection_id="test_connection_id",
+        endpoint=None,
+        request_type="shared",
+        model_params=json.dumps(dict()),
+    )
+
+    sql = utils._apply_unary_ops(
+        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
+    )
+
+    snapshot.assert_match(sql, "out.sql")
diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
index 6ea11d5215..ef150534ee 100644
--- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
+++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
@@ -1105,9 +1105,14 @@ def visit_StringAgg(self, op, *, arg, sep, order_by, where):
         return self.agg.string_agg(expr, sep, where=where)
 
     def visit_AIGenerateBool(self, op, **kwargs):
-        func_name = "AI.GENERATE_BOOL"
+        return sge.func("AI.GENERATE_BOOL", *self._compile_ai_args(**kwargs))
 
+    def visit_AIGenerateInt(self, op, **kwargs):
+        return sge.func("AI.GENERATE_INT", *self._compile_ai_args(**kwargs))
+
+    def _compile_ai_args(self, **kwargs):
         args = []
+
         for key, val in kwargs.items():
             if val is None:
                 continue
@@ -1117,7 +1122,7 @@ def visit_AIGenerateBool(self, op, **kwargs):
 
             args.append(sge.Kwarg(this=sge.Identifier(this=key), expression=val))
 
-        return sge.func(func_name, *args)
+        return args
 
     def visit_FirstNonNullValue(self, op, *, arg):
         return sge.IgnoreNulls(this=sge.FirstValue(this=arg))
diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
index 1f8306bad6..4b855f71c0 100644
--- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
+++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
@@ -30,3 +30,22 @@ def dtype(self) -> dt.Struct:
         return dt.Struct.from_tuples(
             (("result", dt.bool), ("full_resposne", dt.string), ("status", dt.string))
         )
+
+
+@public
+class AIGenerateInt(Value):
+    """Generate integers based on the prompt"""
+
+    prompt: Value
+    connection_id: Value[dt.String]
+    endpoint: Optional[Value[dt.String]]
+    request_type: Value[dt.String]
+    model_params: Optional[Value[dt.String]]
+
+    shape = rlz.shape_like("prompt")
+
+    @attribute
+    def dtype(self) -> dt.Struct:
+        return dt.Struct.from_tuples(
+            (("result", dt.int64), ("full_response", dt.string), ("status", dt.string))
+        )
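
A minimal usage sketch of the new API, based on the docstring and system tests above (the endpoint and model_params values are illustrative, and a session with a default BigQuery connection is assumed):

    import bigframes.pandas as bpd
    import bigframes.bigquery as bbq

    animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])

    # AI.GENERATE_INT returns a struct Series; the "result" field holds the integer answer.
    result = bbq.ai.generate_int(
        ("How many legs does a ", animal, " have?"),
        endpoint="gemini-2.5-flash",
        model_params={"generation_config": {"thinking_config": {"thinking_budget": 0}}},
    )

    # Expected 2, 4, 8 per the docstring example; actual model output may vary.
    print(result.struct.field("result"))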