8000 feat: add ai.generate_int to bigframes.bigquery package by sycai · Pull Request #2109 · googleapis/python-bigquery-dataframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,81 @@ def generate_bool(
return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_int(
    prompt: PROMPT_TYPE,
    *,
    connection_id: str | None = None,
    endpoint: str | None = None,
    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
    model_params: Mapping[Any, Any] | None = None,
) -> series.Series:
    """
    Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])
        >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?"))
        0    {'result': 2, 'full_response': '{"candidates":...
        1    {'result': 4, 'full_response': '{"candidates":...
        2    {'result': 8, 'full_response': '{"candidates":...
        dtype: struct<result: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]

        >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?")).struct.field("result")
        0    2
        1    4
        2    8
        Name: result, dtype: Int64

    Args:
        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
            A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
            or pandas Series.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
            If not provided, the connection from the current session will be used.
        endpoint (str, optional):
            Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
            generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
            uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
            version of Gemini to use.
        request_type (Literal["dedicated", "shared", "unspecified"]):
            Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
            * "dedicated": function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not
              purchased or is not active if Provisioned Throughput quota isn't available.
            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
            * "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
              If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
              If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
        model_params (Mapping[Any, Any]):
            Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.

    Returns:
        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
            * "result": an integer (INT64) value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
            * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
              The generated text is in the text element.
            * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.

    Raises:
        ValueError: If the prompt does not reference at least one Series.
    """

    prompt_context, series_list = _separate_context_and_series(prompt)
    # Validate explicitly instead of using ``assert``: asserts are stripped
    # under ``python -O`` and would surface later as an opaque IndexError.
    if not series_list:
        raise ValueError("The prompt must contain at least one Series.")

    operator = ai_ops.AIGenerateInt(
        prompt_context=tuple(prompt_context),
        connection_id=_resolve_connection_id(series_list[0], connection_id),
        endpoint=endpoint,
        request_type=request_type,
        # The operator carries model params as a JSON string; an empty/None
        # mapping is normalized to None (no MODEL_PARAMS argument in SQL).
        model_params=json.dumps(model_params) if model_params else None,
    )

    return series_list[0]._apply_nary_op(operator, series_list[1:])


def _separate_context_and_series(
prompt: PROMPT_TYPE,
) -> Tuple[List[str | None], List[series.Series]]:
Expand Down
38 changes: 29 additions & 9 deletions bigframes/core/compile/ibis_compiler/scalar_op_registry.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -1975,23 +1975,43 @@ def ai_generate_bool(
*values: ibis_types.Value, op: ops.AIGenerateBool
) -> ibis_types.StructValue:

return ai_ops.AIGenerateBool(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.request_type.upper(), # type: ignore
op.model_params, # type: ignore
).to_expr()


@scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True)
def ai_generate_int(
    *values: ibis_types.Value, op: ops.AIGenerateInt
) -> ibis_types.StructValue:
    """Compiles an AIGenerateInt op into an ibis AI.GENERATE_INT expression.

    Note: ``op`` is always an ``ops.AIGenerateInt`` here (the registration
    guarantees it); the previous ``ops.AIGenerateBool`` annotation was a
    copy-paste error.
    """
    return ai_ops.AIGenerateInt(
        _construct_prompt(values, op.prompt_context),  # type: ignore
        op.connection_id,  # type: ignore
        op.endpoint,  # type: ignore
        op.request_type.upper(),  # type: ignore
        op.model_params,  # type: ignore
    ).to_expr()


def _construct_prompt(
    col_refs: tuple[ibis_types.Value, ...], prompt_context: tuple[str | None, ...]
) -> ibis_types.StructValue:
    """Builds the struct-valued prompt for the AI.GENERATE_* ibis operators.

    ``prompt_context`` holds the literal string parts of the prompt; ``None``
    entries are placeholders that are filled, in order, from ``col_refs``.
    The variadic tuple annotations (`..., ...`) reflect that both inputs can
    have any length.
    """
    prompt: dict[str, ibis_types.Value | str] = {}
    column_ref_idx = 0

    for idx, elem in enumerate(prompt_context):
        if elem is None:
            # Placeholder: substitute the next column reference.
            prompt[f"_field_{idx + 1}"] = col_refs[column_ref_idx]
            column_ref_idx += 1
        else:
            prompt[f"_field_{idx + 1}"] = elem

    return ibis.struct(prompt)


@scalar_op_compiler.register_nary_op(ops.RowKey, pass_op=True)
Expand Down
51 changes: 37 additions & 14 deletions bigframes/core/compile/sqlglot/expressions/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from __future__ import annotations

from dataclasses import asdict
import typing

import sqlglot.expressions as sge

from bigframes import operations as ops
Expand All @@ -25,41 +28,61 @@

@register_nary_op(ops.AIGenerateBool, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
    """Compiles an AIGenerateBool op into a SQL AI.GENERATE_BOOL call."""
    prompt_arg = _construct_prompt(exprs, op.prompt_context)
    named_args = _construct_named_args(op)
    return sge.func("AI.GENERATE_BOOL", prompt_arg, *named_args)


@register_nary_op(ops.AIGenerateInt, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression:
    """Compiles an AIGenerateInt op into a SQL AI.GENERATE_INT call."""
    prompt_arg = _construct_prompt(exprs, op.prompt_context)
    named_args = _construct_named_args(op)
    return sge.func("AI.GENERATE_INT", prompt_arg, *named_args)


def _construct_prompt(
    exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
) -> sge.Kwarg:
    """Builds the ``prompt`` keyword argument for the AI.GENERATE_* SQL functions.

    ``prompt_context`` holds the literal string parts of the prompt; ``None``
    entries are placeholders that are filled, in order, from ``exprs``.
    """
    prompt: list[str | sge.Expression] = []
    column_ref_idx = 0

    for elem in prompt_context:
        if elem is None:
            prompt.append(exprs[column_ref_idx].expr)
            # Advance past the consumed column reference; without this every
            # placeholder in the prompt would reuse the first column.
            column_ref_idx += 1
        else:
            prompt.append(sge.Literal.string(elem))

    return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))


def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
    """Builds the named SQL arguments shared by the AI.GENERATE_* functions.

    Reads the op's fields generically via ``dataclasses.asdict`` so the same
    helper serves every AI operator dataclass (bool, int, ...). Optional
    fields (``endpoint``, ``model_params``) are omitted when None.
    """
    args = []

    op_args = asdict(op)

    connection_id = typing.cast(str, op_args["connection_id"])
    args.append(
        sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
    )

    endpoint = typing.cast(str, op_args.get("endpoint", None))
    if endpoint is not None:
        args.append(
            sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint))
        )

    # BigQuery expects the request type in upper case ("DEDICATED", etc.).
    request_type = typing.cast(str, op_args["request_type"]).upper()
    args.append(
        sge.Kwarg(this="request_type", expression=sge.Literal.string(request_type))
    )

    model_params = typing.cast(str, op_args.get("model_params", None))
    if model_params is not None:
        args.append(
            sge.Kwarg(
                this="model_params",
                # sge.JSON requires the SQLGlot version to be at least 25.18.0.
                # PARSE_JSON won't work as the function requires a JSON literal.
                expression=sge.JSON(this=sge.Literal.string(model_params)),
            )
        )

    return args
3 changes: 2 additions & 1 deletion bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from bigframes.operations.ai_ops import AIGenerateBool
from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateInt
from bigframes.operations.array_ops import (
ArrayIndexOp,
ArrayReduceOp,
Expand Down Expand Up @@ -413,6 +413,7 @@
"GeoStDistanceOp",
# AI ops
"AIGenerateBool",
"AIGenerateInt",
# Numpy ops mapping
"NUMPY_TO_BINOP",
"NUMPY_TO_OP",
Expand Down
23 changes: 22 additions & 1 deletion bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
class AIGenerateBool(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate_bool"

# None are the placeholders for column references.
prompt_context: Tuple[str | None, ...]
connection_id: str
endpoint: str | None
Expand All @@ -45,3 +44,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
)
)
)


@dataclasses.dataclass(frozen=True)
class AIGenerateInt(base_ops.NaryOp):
    """Scalar operator for the BigQuery AI.GENERATE_INT function."""

    name: ClassVar[str] = "ai_generate_int"

    # Literal string parts of the prompt; None entries are the placeholders
    # for column references supplied as the op's inputs.
    prompt_context: Tuple[str | None, ...]
    connection_id: str
    endpoint: str | None
    request_type: Literal["dedicated", "shared", "unspecified"]
    # JSON-serialized model parameters, or None when not provided.
    model_params: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Returns the output dtype: struct<result INT64, full_response JSON, status STRING>."""
        return pd.ArrowDtype(
            pa.struct(
                (
                    pa.field("result", pa.int64()),
                    pa.field("full_response", dtypes.JSON_ARROW_TYPE),
                    pa.field("status", pa.string()),
                )
            )
        )
73 changes: 56 additions & 17 deletions tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from packaging import version
import pandas as pd
import pyarrow as pa
import pytest
import sqlglot

from bigframes import dtypes, series
import bigframes.bigquery as bbq
import bigframes.pandas as bpd


def test_ai_generate_bool(session):
s1 = bpd.Series(["apple", "bear"], session=session)
def test_ai_function_pandas_input(session):
s1 = pd.Series(["apple", "bear"])
s2 = bpd.Series(["fruit", "tree"], session=session)
prompt = (s1, " is a ", s2)

Expand All @@ -42,12 +42,20 @@ def test_ai_generate_bool(session):
)


def test_ai_generate_bool_with_pandas(session):
s1 = pd.Series(["apple", "bear"])
def test_ai_function_compile_model_params(session):
if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
pytest.skip(
"Skip test because SQLGLot cannot compile model params to JSON at this version."
)

s1 = bpd.Series(["apple", "bear"], session=session)
s2 = bpd.Series(["fruit", "tree"], session=session)
prompt = (s1, " is a ", s2)
model_params = {"generation_config": {"thinking_config": {"thinking_budget": 0}}}

result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
result = bbq.ai.generate_bool(
prompt, endpoint="gemini-2.5-flash", model_params=model_params
)

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
Expand All @@ -61,20 +69,12 @@ def test_ai_generate_bool_with_pandas(session):
)


def test_ai_generate_bool_with_model_params(session):
if sys.version_info < (3, 12):
pytest.skip(
"Skip test because SQLGLot cannot compile model params to JSON at this env."
)

def test_ai_generate_bool(session):
s1 = bpd.Series(["apple", "bear"], session=session)
s2 = bpd.Series(["fruit", "tree"], session=session)
prompt = (s1, " is a ", s2)
model_params = {"generation_config": {"thinking_config": {"thinking_budget": 0}}}

result = bbq.ai.generate_bool(
prompt, endpoint="gemini-2.5-flash", model_params=model_params
)
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
Expand Down Expand Up @@ -107,5 +107,44 @@ def test_ai_generate_bool_multi_model(session):
)


def test_ai_generate_int(session):
    """AI.GENERATE_INT over a text prompt yields a non-null struct Series."""
    animal = bpd.Series(["Cat"], session=session)

    result = bbq.ai.generate_int(
        ("How many legs does a ", animal, " have?"), endpoint="gemini-2.5-flash"
    )

    expected_dtype = pd.ArrowDtype(
        pa.struct(
            (
                pa.field("result", pa.int64()),
                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
                pa.field("status", pa.string()),
            )
        )
    )
    assert _contains_no_nulls(result)
    assert result.dtype == expected_dtype


def test_ai_generate_int_multi_model(session):
    """AI.GENERATE_INT accepts multimodal prompts (image blob column)."""
    df = session.from_glob_path(
        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
    )

    result = bbq.ai.generate_int(
        ("How many animals are there in the picture ", df["image"])
    )

    expected_dtype = pd.ArrowDtype(
        pa.struct(
            (
                pa.field("result", pa.int64()),
                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
                pa.field("status", pa.string()),
            )
        )
    )
    assert _contains_no_nulls(result)
    assert result.dtype == expected_dtype


def _contains_no_nulls(s: series.Series) -> bool:
    """Returns True iff every element of ``s`` is non-null."""
    non_null_count = s.count()
    return non_null_count == len(s)
Loading
0