From fc223f87bac09019914aef4131d99bfc9a507cfc Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Mon, 19 Aug 2024 20:37:52 +0000 Subject: [PATCH 1/2] feat: add llm.TextEmbeddingGenerator to support new embedding models --- bigframes/ml/llm.py | 166 +++++++++++++++++++++++++++++- bigframes/ml/loader.py | 3 + tests/system/small/ml/test_llm.py | 41 ++++++++ 3 files changed, 207 insertions(+), 3 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 2517178d89..ff4710258a 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -40,11 +40,18 @@ _EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko" _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual" -_EMBEDDING_GENERATOR_ENDPOINTS = ( +_PALM2_EMBEDDING_GENERATOR_ENDPOINTS = ( _EMBEDDING_GENERATOR_GECKO_ENDPOINT, _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT, ) +_TEXT_EMBEDDING_004_ENDPOINT = "text-embedding-004" +_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT = "text-multilingual-embedding-002" +_TEXT_EMBEDDING_ENDPOINTS = ( + _TEXT_EMBEDDING_004_ENDPOINT, + _TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT, +) + _GEMINI_PRO_ENDPOINT = "gemini-pro" _GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514" _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514" @@ -57,6 +64,7 @@ _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" _ML_EMBED_TEXT_STATUS = "ml_embed_text_status" +_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status" @log_adapter.class_logger @@ -387,6 +395,10 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator: class PaLM2TextEmbeddingGenerator(base.BaseEstimator): """PaLM2 text embedding generator LLM model. + .. note:: + Models in this class are outdated and going to be deprecated. To use the most updated text embedding models, go to the TextEmbeddingGenerator class. + + Args: model_name (str, Default to "textembedding-gecko"): The model for text embedding. “textembedding-gecko” returns model embeddings for text inputs. @@ -447,9 +459,9 @@ def _create_bqml_model(self): iam_role="aiplatform.user", ) - if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: + if self.model_name not in _PALM2_EMBEDDING_GENERATOR_ENDPOINTS: raise ValueError( - f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}." + f"Model name {self.model_name} is not supported. We only support {', '.join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS)}." ) endpoint = ( @@ -551,6 +563,154 @@ def to_gbq( return new_model.session.read_gbq_model(model_name) +@log_adapter.class_logger +class TextEmbeddingGenerator(base.BaseEstimator): + """Text embedding generator LLM model. + + Args: + model_name (str, Default to "text-embedding-004"): + The model for text embedding. Possible values are "text-embedding-004" or "text-multilingual-embedding-002". + text-embedding models returns model embeddings for text inputs. + text-multilingual-embedding models returns model embeddings for text inputs which support over 100 languages. + Default to "text-embedding-004". + session (bigframes.Session or None): + BQ session to create the model. If None, use the global default session. + connection_name (str or None): + Connection to connect with remote service. str of the format ... + If None, use default connection in session context. + """ + + def __init__( + self, + *, + model_name: Literal[ + "text-embedding-004", "text-multilingual-embedding-002" + ] = "text-embedding-004", + session: Optional[bigframes.Session] = None, + connection_name: Optional[str] = None, + ): + self.model_name = model_name + self.session = session or bpd.get_global_session() + self._bq_connection_manager = self.session.bqconnectionmanager + + connection_name = connection_name or self.session._bq_connection + self.connection_name = clients.resolve_full_bq_connection_name( + connection_name, + default_project=self.session._project, + default_location=self.session._location, + ) + + self._bqml_model_factory = globals.bqml_model_factory() + self._bqml_model: core.BqmlModel = self._create_bqml_model() + + def _create_bqml_model(self): + # Parse and create connection if needed. + if not self.connection_name: + raise ValueError( + "Must provide connection_name, either in constructor or through session options." + ) + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", + ) + + if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS: + raise ValueError( + f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_EMBEDDING_ENDPOINTS)}." + ) + + options = { + "endpoint": self.model_name, + } + return self._bqml_model_factory.create_remote_model( + session=self.session, connection_name=self.connection_name, options=options + ) + + @classmethod + def _from_bq( + cls, session: bigframes.Session, bq_model: bigquery.Model + ) -> TextEmbeddingGenerator: + assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" + assert "remoteModelInfo" in bq_model._properties + assert "endpoint" in bq_model._properties["remoteModelInfo"] + assert "connection" in bq_model._properties["remoteModelInfo"] + + # Parse the remote model endpoint + bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] + model_connection = bq_model._properties["remoteModelInfo"]["connection"] + model_endpoint = bqml_endpoint.split("/")[-1] + + model = cls( + session=session, + model_name=model_endpoint, # type: ignore + connection_name=model_connection, + ) + + model._bqml_model = core.BqmlModel(session, bq_model) + return model + + def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: + """Predict the result from input DataFrame. + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series): + Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples. + + Returns: + bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values. + """ + + # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models + (X,) = utils.convert_to_dataframe(X) + + if len(X.columns) != 1: + raise ValueError( + f"Only support one column as input. {constants.FEEDBACK_LINK}" + ) + + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "content"}) + + options = { + "flatten_json_output": True, + } + + df = self._bqml_model.generate_embedding(X, options) + + if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any(): + warnings.warn( + f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.", + RuntimeWarning, + ) + + return df + + def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator: + """Save the model to BigQuery. + + Args: + model_name (str): + The name of the model. + replace (bool, default False): + Determine whether to replace if the model already exists. Default to False. + + Returns: + PaLM2TextEmbeddingGenerator: Saved model.""" + + new_model = self._bqml_model.copy(model_name, replace) + return new_model.session.read_gbq_model(model_name) + + @log_adapter.class_logger class GeminiTextGenerator(base.BaseEstimator): """Gemini text generator LLM model. diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index 515fb50c6f..bd01342152 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -63,6 +63,8 @@ llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, + llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator, + llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator, } ) @@ -84,6 +86,7 @@ def from_bq( imported.XGBoostModel, llm.PaLM2TextGenerator, llm.PaLM2TextEmbeddingGenerator, + llm.TextEmbeddingGenerator, pipeline.Pipeline, compose.ColumnTransformer, preprocessing.PreprocessingType, diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index b926004fd8..c2f62096d0 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -304,6 +304,47 @@ def test_embedding_generator_predict_series_success( assert len(value) == 768 +@pytest.mark.parametrize( + "model_name", + ("text-embedding-004", "text-multilingual-embedding-002"), +) +def test_create_load_text_embedding_generator_model( + dataset_id, model_name, session, bq_connection +): + text_embedding_model = llm.TextEmbeddingGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + assert text_embedding_model is not None + assert text_embedding_model._bqml_model is not None + + # save, load to ensure configuration was kept + reloaded_model = text_embedding_model.to_gbq( + f"{dataset_id}.temp_text_model", replace=True + ) + assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name + assert reloaded_model.connection_name == bq_connection + assert reloaded_model.model_name == model_name + + +@pytest.mark.parametrize( + "model_name", + ("text-embedding-004", "text-multilingual-embedding-002"), +) +@pytest.mark.flaky(retries=2) +def test_gemini_text_embedding_generator_predict_default_params_success( + llm_text_df, model_name, session, bq_connection +): + text_embedding_model = llm.TextEmbeddingGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + df = text_embedding_model.predict(llm_text_df).to_pandas() + assert df.shape == (3, 4) + assert "ml_generate_embedding_result" in df.columns + series = df["ml_generate_embedding_result"] + value = series[0] + assert len(value) == 768 + + @pytest.mark.parametrize( "model_name", ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"), From c69b947920938a8792433899b3ad530d8ac51208 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Mon, 19 Aug 2024 22:18:31 +0000 Subject: [PATCH 2/2] fix docs --- bigframes/ml/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index ff4710258a..45634423c6 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -705,7 +705,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat Determine whether to replace if the model already exists. Default to False. Returns: - PaLM2TextEmbeddingGenerator: Saved model.""" + TextEmbeddingGenerator: Saved model.""" new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name)