feat: add GeminiTextGenerator.predict structured output by GarrettWu · Pull Request #1653 · googleapis/python-bigquery-dataframes · GitHub
Merged
22 changes: 8 additions & 14 deletions bigframes/ml/base.py
@@ -22,7 +22,7 @@
"""

import abc
from typing import Callable, cast, Mapping, Optional, TypeVar, Union
from typing import cast, Optional, TypeVar, Union
import warnings

import bigframes_vendored.sklearn.base
@@ -244,18 +244,12 @@ def fit(


class RetriableRemotePredictor(BaseEstimator):
@property
@abc.abstractmethod
def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
pass

@property
@abc.abstractmethod
def _status_col(self) -> str:
pass

def _predict_and_retry(
self, X: bpd.DataFrame, options: Mapping, max_retries: int
self,
bqml_model_predict_tvf: core.BqmlModel.TvfDef,
X: bpd.DataFrame,
options: dict,
max_retries: int,
) -> bpd.DataFrame:
assert self._bqml_model is not None

@@ -269,9 +263,9 @@
warnings.warn(msg, category=RuntimeWarning)
break

df = self._predict_func(df_fail, options)
df = bqml_model_predict_tvf.tvf(self._bqml_model, df_fail, options)

success = df[self._status_col].str.len() == 0
success = df[bqml_model_predict_tvf.status_col].str.len() == 0
df_succ = df[success]
df_fail = df[~success]

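With the abstract `_predict_func`/`_status_col` properties gone, the retry loop depends only on the `TvfDef` the caller passes in: rows whose status column is non-empty are re-submitted until they succeed or the retry budget runs out. A minimal sketch of that pattern, using plain pandas objects and a caller-supplied `tvf` callable in place of the real `bpd.DataFrame`/`BqmlModel` types (names here are illustrative):

import pandas as pd

def predict_and_retry(tvf, status_col: str, df: pd.DataFrame, options: dict, max_retries: int) -> pd.DataFrame:
    """Re-run the TVF on rows whose status column is non-empty, up to max_retries extra attempts."""
    df_succ, df_fail = pd.DataFrame(), df
    for _ in range(max_retries + 1):
        if df_fail.empty:
            break
        result = tvf(df_fail, options)          # one TVF call per attempt
        ok = result[status_col].str.len() == 0  # empty status string means success
        df_succ = pd.concat([df_succ, result[ok]])
        df_fail = result[~ok]                   # only failed rows are retried
    return pd.concat([df_succ, df_fail])        # rows that never succeeded keep their status

The real implementation additionally emits a RuntimeWarning once the retry budget is exhausted, as the hunk above shows.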
31 changes: 29 additions & 2 deletions bigframes/ml/core.py
@@ -16,6 +16,7 @@

from __future__ import annotations

import dataclasses
import datetime
from typing import Callable, cast, Iterable, Mapping, Optional, Union
import uuid
@@ -44,6 +45,11 @@ class BqmlModel(BaseBqml):
BigQuery DataFrames ML.
"""

@dataclasses.dataclass
class TvfDef:
tvf: Callable[[BqmlModel, bpd.DataFrame, dict], bpd.DataFrame]
status_col: str

def __init__(self, session: bigframes.Session, model: bigquery.Model):
self._session = session
self._model = model
@@ -159,8 +165,9 @@ def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
def generate_text(
self,
input_data: bpd.DataFrame,
options: Mapping[str, int | float],
options: dict[str, Union[int, float, bool]],
) -> bpd.DataFrame:
options["flatten_json_output"] = True
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text(
@@ -169,11 +176,14 @@ def generate_text(
),
)

generate_text_tvf = TvfDef(generate_text, "ml_generate_text_status")

def generate_embedding(
self,
input_data: bpd.DataFrame,
options: Mapping[str, int | float],
options: dict[str, Union[int, float, bool]],
) -> bpd.DataFrame:
options["flatten_json_output"] = True
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_embedding(
@@ -182,6 +192,23 @@ def generate_embedding(
),
)

generate_embedding_tvf = TvfDef(generate_embedding, "ml_generate_embedding_status")

def generate_table(
self,
input_data: bpd.DataFrame,
options: dict[str, Union[int, float, bool, Mapping]],
) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ai_generate_table(
source_sql=source_sql,
struct_options=options,
),
)

generate_table_tvf = TvfDef(generate_table, "status")

def detect_anomalies(
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
) -> bpd.DataFrame:
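Each `TvfDef` class attribute pairs an unbound prediction method with the name of its status column, and `flatten_json_output` is now injected inside `generate_text`/`generate_embedding` instead of by every caller. A rough, self-contained illustration of the binding idea (the `Model` class below is a stand-in, not the real `BqmlModel`):

import dataclasses
from typing import Callable

class Model:
    @dataclasses.dataclass
    class TvfDef:
        tvf: Callable[["Model", list, dict], list]
        status_col: str

    def generate_text(self, rows: list, options: dict) -> list:
        # Placeholder body; the real method issues an ML TVF query against BigQuery.
        return [{"result": r, "ml_generate_text_status": ""} for r in rows]

    # Pair the unbound method with the name of its status column.
    generate_text_tvf = TvfDef(generate_text, "ml_generate_text_status")

model = Model()
tvf_def = Model.generate_text_tvf
# The stored callable is a plain function, so the instance is passed explicitly,
# mirroring bqml_model_predict_tvf.tvf(self._bqml_model, df_fail, options) in base.py.
rows = tvf_def.tvf(model, ["hello"], {})
print(rows[0][tvf_def.status_col])  # "" -> success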
2 changes: 1 addition & 1 deletion bigframes/ml/globals.py
@@ -19,7 +19,7 @@
_BASE_SQL_GENERATOR = sql.BaseSqlGenerator()
_BQML_MODEL_FACTORY = core.BqmlModelFactory()

_SUPPORTED_DTYPES = (
_REMOTE_MODEL_SUPPORTED_DTYPES = (
"bool",
"string",
"int64",
35 changes: 19 additions & 16 deletions bigframes/ml/imported.py
@@ -216,8 +216,8 @@ def __init__(
self,
model_path: str,
*,
input: Mapping[str, str] = {},
output: Mapping[str, str] = {},
input: Optional[Mapping[str, str]] = None,
output: Optional[Mapping[str, str]] = None,
session: Optional[bigframes.session.Session] = None,
):
self.session = session or bpd.get_global_session()
@@ -234,20 +234,23 @@ def _create_bqml_model(self):
return self._bqml_model_factory.create_imported_model(
session=self.session, options=options
)
else:
for io in (self.input, self.output):
for v in io.values():
if v not in globals._SUPPORTED_DTYPES:
raise ValueError(
f"field_type {v} is not supported. We only support {', '.join(globals._SUPPORTED_DTYPES)}."
)

return self._bqml_model_factory.create_xgboost_imported_model(
session=self.session,
input=self.input,
output=self.output,
options=options,
)
if not self.input or not self.output:
raise ValueError("input and output must both or neigher be set.")
self.input = {
k: utils.standardize_type(v, globals._REMOTE_MODEL_SUPPORTED_DTYPES)
for k, v in self.input.items()
}
self.output = {
k: utils.standardize_type(v, globals._REMOTE_MODEL_SUPPORTED_DTYPES)
for k, v in self.output.items()
}

return self._bqml_model_factory.create_xgboost_imported_model(
session=self.session,
input=self.input,
output=self.output,
options=options,
)

@classmethod
def _from_bq(
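Changing the `input`/`output` defaults from `{}` to `None` avoids Python's shared-mutable-default pitfall and lets the constructor tell "not provided" apart from "explicitly empty", which is what the new both-or-neither check relies on. A minimal sketch of the pitfall the change sidesteps (function names are illustrative only):

def with_mutable_default(mapping: dict = {}):  # a single dict object shared across every call
    mapping["touched"] = True
    return mapping

def with_none_default(mapping=None):
    mapping = {} if mapping is None else dict(mapping)  # fresh copy per call
    mapping["touched"] = True
    return mapping

print(with_mutable_default() is with_mutable_default())  # True: same object mutated twice
print(with_none_default() is with_none_default())        # False: independent dicts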
111 changes: 44 additions & 67 deletions bigframes/ml/llm.py
@@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Callable, cast, Iterable, Literal, Mapping, Optional, Union
from typing import cast, Iterable, Literal, Mapping, Optional, Union
import warnings

import bigframes_vendored.constants as constants
@@ -92,10 +92,6 @@
_CLAUDE_3_OPUS_ENDPOINT,
)


_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"

_MODEL_NOT_SUPPORTED_WARNING = (
"Model name '{model_name}' is not supported. "
"We are currently aware of the following models: {known_models}. "
@@ -193,18 +189,6 @@ def _from_bq(
model._bqml_model = core.BqmlModel(session, bq_model)
return model

@property
def _predict_func(
self,
) -> Callable[
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
]:
return self._bqml_model.generate_embedding

@property
def _status_col(self) -> str:
return _ML_GENERATE_EMBEDDING_STATUS

def predict(
self, X: utils.ArrayType, *, max_retries: int = 0
) -> bigframes.dataframe.DataFrame:
@@ -233,11 +217,14 @@ def predict(
col_label = cast(blocks.Label, X.columns[0])
X = X.rename(columns={col_label: "content"})

options = {
"flatten_json_output": True,
}
options: dict = {}

return self._predict_and_retry(X, options=options, max_retries=max_retries)
return self._predict_and_retry(
core.BqmlModel.generate_embedding_tvf,
X,
options=options,
max_retries=max_retries,
)

def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
"""Save the model to BigQuery.
@@ -339,18 +326,6 @@ def _from_bq(
model._bqml_model = core.BqmlModel(session, bq_model)
return model

@property
def _predict_func(
self,
) -> Callable[
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
]:
return self._bqml_model.generate_embedding

@property
def _status_col(self) -> str:
return _ML_GENERATE_EMBEDDING_STATUS

def predict(
self, X: utils.ArrayType, *, max_retries: int = 0
) -> bigframes.dataframe.DataFrame:
@@ -384,11 +359,14 @@ def predict(
if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)

options = {
"flatten_json_output": True,
}
options: dict = {}

return self._predict_and_retry(X, options=options, max_retries=max_retries)
return self._predict_and_retry(
core.BqmlModel.generate_embedding_tvf,
X,
options=options,
max_retries=max_retries,
)

def to_gbq(
self, model_name: str, replace: bool = False
@@ -533,18 +511,6 @@ def _bqml_options(self) -> dict:
}
return options

@property
def _predict_func(
self,
) -> Callable[
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
]:
return self._bqml_model.generate_text

@property
def _status_col(self) -> str:
return _ML_GENERATE_TEXT_STATUS

def fit(
self,
X: utils.ArrayType,
@@ -596,6 +562,7 @@ def predict(
ground_with_google_search: bool = False,
max_retries: int = 0,
prompt: Optional[Iterable[Union[str, bigframes.series.Series]]] = None,
output_schema: Optional[Mapping[str, str]] = None,
) -> bigframes.dataframe.DataFrame:
"""Predict the result from input DataFrame.

@@ -645,6 +612,9 @@
Construct a prompt struct column for prediction based on the input. The input must be an Iterable that can take string literals,
such as "summarize", string column(s) of X, such as X["str_col"], or blob column(s) of X, such as X["blob_col"].
It creates a struct column of the items of the iterable, and uses the concatenated result as the input prompt. No-op if set to None.
output_schema (Mapping[str, str] or None, default None):
The schema used to generate structured output as a bigframes DataFrame. The schema is a mapping of <column_name> to <type> strings.
Supported types are int64, float64, bool and string. If None, the plain text result is returned.
Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
"""
@@ -707,16 +677,31 @@
col_label = cast(blocks.Label, X.columns[0])
X = X.rename(columns={col_label: "prompt"})

options = {
options: dict = {
"temperature": temperature,
"max_output_tokens": max_output_tokens,
"top_k": top_k,
# "top_k": top_k, # TODO(garrettwu): the option is deprecated in Gemini 1.5 forward.
"top_p": top_p,
"flatten_json_output": True,
"ground_with_google_search": ground_with_google_search,
}
if output_schema:
output_schema = {
k: utils.standardize_type(v) for k, v in output_schema.items()
}
options["output_schema"] = output_schema
return self._predict_and_retry(
core.BqmlModel.generate_table_tvf,
X,
options=options,
max_retries=max_retries,
)

return self._predict_and_retry(X, options=options, max_retries=max_retries)
return self._predict_and_retry(
core.BqmlModel.generate_text_tvf,
X,
options=options,
max_retries=max_retries,
)

def score(
self,
@@ -916,18 +901,6 @@ def _bqml_options(self) -> dict:
}
return options

@property
def _predict_func(
self,
) -> Callable[
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
]:
return self._bqml_model.generate_text

@property
def _status_col(self) -> str:
return _ML_GENERATE_TEXT_STATUS

def predict(
self,
X: utils.ArrayType,
@@ -1000,10 +973,14 @@ def predict(
"max_output_tokens": max_output_tokens,
"top_k": top_k,
"top_p": top_p,
"flatten_json_output": True,
}

return self._predict_and_retry(X, options=options, max_retries=max_retries)
return self._predict_and_retry(
core.BqmlModel.generate_text_tvf,
X,
options=options,
max_retries=max_retries,
)

def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
"""Save the model to BigQuery.