diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 93e2ba825f..78f3369daf 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -17,6 +17,7 @@ from __future__ import annotations from typing import cast, Literal, Optional, Union +import warnings import bigframes from bigframes import clients, constants @@ -24,15 +25,22 @@ from bigframes.ml import base, core, globals, utils import bigframes.pandas as bpd -_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT = "text-bison" -_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT = "text-bison-32k" -_TEXT_GENERATE_RESULT_COLUMN = "ml_generate_text_llm_result" +_TEXT_GENERATOR_BISON_ENDPOINT = "text-bison" +_TEXT_GENERATOR_BISON_32K_ENDPOINT = "text-bison-32k" +_TEXT_GENERATOR_ENDPOINTS = ( + _TEXT_GENERATOR_BISON_ENDPOINT, + _TEXT_GENERATOR_BISON_32K_ENDPOINT, +) -_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT = "textembedding-gecko" -_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT = ( - "textembedding-gecko-multilingual" +_EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko" +_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual" +_EMBEDDING_GENERATOR_ENDPOINTS = ( + _EMBEDDING_GENERATOR_GECKO_ENDPOINT, + _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT, ) -_EMBED_TEXT_RESULT_COLUMN = "text_embedding" + +_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" +_ML_EMBED_TEXT_STATUS = "ml_embed_text_status" class PaLM2TextGenerator(base.Predictor): @@ -90,18 +98,16 @@ def _create_bqml_model(self): connection_id=connection_name_parts[2], iam_role="aiplatform.user", ) - if self.model_name == _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT: - options = { - "endpoint": _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT, - } - elif self.model_name == _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT: - options = { - "endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT, - } - else: + + if self.model_name not in _TEXT_GENERATOR_ENDPOINTS: raise ValueError( - f"Model name {self.model_name} is not supported. We only support {_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT}." + f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_GENERATOR_ENDPOINTS)}." ) + + options = { + "endpoint": self.model_name, + } + return self._bqml_model_factory.create_remote_model( session=self.session, connection_name=self.connection_name, options=options ) @@ -182,7 +188,16 @@ def predict( "top_p": top_p, "flatten_json_output": True, } - return self._bqml_model.generate_text(X, options) + + df = self._bqml_model.generate_text(X, options) + + if (df[_ML_GENERATE_TEXT_STATUS] != "").any(): + warnings.warn( + f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", + RuntimeWarning, + ) + + return df class PaLM2TextEmbeddingGenerator(base.Predictor): @@ -241,19 +256,15 @@ def _create_bqml_model(self): connection_id=connection_name_parts[2], iam_role="aiplatform.user", ) - if self.model_name == "textembedding-gecko": - options = { - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT, - } - elif self.model_name == _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT: - options = { - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT, - } - else: + + if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: raise ValueError( - f"Model name {self.model_name} is not supported. We only support {_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT}." + f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}." ) + options = { + "endpoint": self.model_name, + } return self._bqml_model_factory.create_remote_model( session=self.session, connection_name=self.connection_name, options=options ) @@ -284,4 +295,13 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: options = { "flatten_json_output": True, } - return self._bqml_model.generate_text_embedding(X, options) + + df = self._bqml_model.generate_text_embedding(X, options) + + if (df[_ML_EMBED_TEXT_STATUS] != "").any(): + warnings.warn( + f"Some predictions failed. Check column {_ML_EMBED_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", + RuntimeWarning, + ) + + return df