From f003deb2bd355cc2c1f8a2886ed8596a62f6986a Mon Sep 17 00:00:00 2001
From: Garrett Wu
Date: Tue, 31 Dec 2024 02:14:13 +0000
Subject: [PATCH 1/2] chore: fix wordings of Gemini max_retries

---
 bigframes/ml/llm.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py
index 2427009cf1..d42138b006 100644
--- a/bigframes/ml/llm.py
+++ b/bigframes/ml/llm.py
@@ -986,9 +986,8 @@ def predict(
                 The default is `False`.
 
             max_retries (int, default 0):
-                Max number of retry rounds if any rows failed in the prediction. Each round need to make progress (has succeeded rows) to continue the next retry round.
-                Each round will append newly succeeded rows. When the max retry rounds is reached, the remaining failed rows will be appended to the end of the result.
-
+                Max number of retries if the prediction of any rows fails. Each retry needs to make progress (i.e. successfully predict some rows) for retrying to continue.
+                Each retry appends the newly succeeded rows. When max retries are reached, the remaining rows (the ones without successful predictions) are appended to the end of the result.
 
         Returns:
             bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
         """
@@ -1034,11 +1033,15 @@ def predict(
         for _ in range(max_retries + 1):
             df = self._bqml_model.generate_text(df_fail, options)
 
-            df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0]
-            df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0]
+            success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0
+            df_succ = df[success]
+            df_fail = df[~success]
 
             if df_succ.empty:
-                warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
+                if max_retries > 0:
+                    warnings.warn(
+                        "Can't make any progress, stop retrying.", RuntimeWarning
+                    )
                 break
 
             df_result = (

From 726dfc25f218028013fbf099c1605b3e3f88cfb2 Mon Sep 17 00:00:00 2001
From: Garrett Wu
Date: Sat, 4 Jan 2025 00:42:49 +0000
Subject: [PATCH 2/2] feat: add max_retries to TextEmbeddingGenerator and Claude3TextGenerator

---
 bigframes/ml/base.py              |  64 ++++++-
 bigframes/ml/llm.py               | 109 ++++++-----
 tests/system/small/ml/test_llm.py | 288 +++++++++++++++++++++++++++---
 3 files changed, 369 insertions(+), 92 deletions(-)

diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py
index 4058647adb..a2c122f8c7 100644
--- a/bigframes/ml/base.py
+++ b/bigframes/ml/base.py
@@ -22,7 +22,8 @@
 """
 
 import abc
-from typing import cast, Optional, TypeVar
+from typing import Callable, cast, Mapping, Optional, TypeVar
+import warnings
 
 import bigframes_vendored.sklearn.base
@@ -77,6 +78,9 @@ def fit_transform(self, x_train: Union[DataFrame, Series], y_train: Union[DataFr
             ...
""" + def __init__(self): + self._bqml_model: Optional[core.BqmlModel] = None + def __repr__(self): """Print the estimator's constructor with all non-default parameter values.""" @@ -95,9 +99,6 @@ def __repr__(self): class Predictor(BaseEstimator): """A BigQuery DataFrames ML Model base class that can be used to predict outputs.""" - def __init__(self): - self._bqml_model: Optional[core.BqmlModel] = None - @abc.abstractmethod def predict(self, X): pass @@ -213,12 +214,61 @@ def fit( return self._fit(X, y) +class RetriableRemotePredictor(BaseEstimator): + @property + @abc.abstractmethod + def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]: + pass + + @property + @abc.abstractmethod + def _status_col(self) -> str: + pass + + def _predict_and_retry( + self, X: bpd.DataFrame, options: Mapping, max_retries: int + ) -> bpd.DataFrame: + assert self._bqml_model is not None + + df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder + df_fail = X + for _ in range(max_retries + 1): + df = self._predict_func(df_fail, options) + + success = df[self._status_col].str.len() == 0 + df_succ = df[success] + df_fail = df[~success] + + if df_succ.empty: + if max_retries > 0: + warnings.warn( + "Can't make any progress, stop retrying.", RuntimeWarning + ) + break + + df_result = ( + bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ + ) + + if df_fail.empty: + break + + if not df_fail.empty: + warnings.warn( + f"Some predictions failed. Check column {self._status_col} for detailed status. You may want to filter the failed rows and retry.", + RuntimeWarning, + ) + + df_result = cast( + bpd.DataFrame, + bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail, + ) + return df_result + + class BaseTransformer(BaseEstimator): """Transformer base class.""" - def __init__(self): - self._bqml_model: Optional[core.BqmlModel] = None - @abc.abstractmethod def _keys(self): pass diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index d42138b006..e6825f80bb 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import cast, Literal, Optional +from typing import Callable, cast, Literal, Mapping, Optional import warnings import bigframes_vendored.constants as constants @@ -616,7 +616,7 @@ def to_gbq( @log_adapter.class_logger -class TextEmbeddingGenerator(base.BaseEstimator): +class TextEmbeddingGenerator(base.RetriableRemotePredictor): """Text embedding generator LLM model. Args: @@ -715,18 +715,33 @@ def _from_bq( model._bqml_model = core.BqmlModel(session, bq_model) return model - def predict(self, X: utils.ArrayType) -> bpd.DataFrame: + @property + def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]: + return self._bqml_model.generate_embedding + + @property + def _status_col(self) -> str: + return _ML_GENERATE_EMBEDDING_STATUS + + def predict(self, X: utils.ArrayType, *, max_retries: int = 0) -> bpd.DataFrame: """Predict the result from input DataFrame. Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series): Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction. + max_retries (int, default 0): + Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry. 
+                Each retry appends the newly succeeded rows. When max retries are reached, the remaining rows (the ones without successful predictions) are appended to the end of the result.
+
         Returns:
             bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
         """
+        if max_retries < 0:
+            raise ValueError(
+                f"max_retries must be larger than or equal to 0, but is {max_retries}."
+            )
 
-        # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
         if len(X.columns) == 1:
@@ -738,15 +753,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
             "flatten_json_output": True,
         }
 
-        df = self._bqml_model.generate_embedding(X, options)
-
-        if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any():
-            warnings.warn(
-                f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.",
-                RuntimeWarning,
-            )
-
-        return df
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
 
     def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
         """Save the model to BigQuery.
@@ -765,7 +772,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
 
 
 @log_adapter.class_logger
-class GeminiTextGenerator(base.BaseEstimator):
+class GeminiTextGenerator(base.RetriableRemotePredictor):
     """Gemini text generator LLM model.
 
     Args:
@@ -891,6 +898,14 @@ def _bqml_options(self) -> dict:
         }
         return options
 
+    @property
+    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
+        return self._bqml_model.generate_text
+
+    @property
+    def _status_col(self) -> str:
+        return _ML_GENERATE_TEXT_STATUS
+
     def fit(
         self,
         X: utils.ArrayType,
@@ -1028,41 +1043,7 @@ def predict(
             "ground_with_google_search": ground_with_google_search,
         }
 
-        df_result = bpd.DataFrame(session=self._bqml_model.session)  # placeholder
-        df_fail = X
-        for _ in range(max_retries + 1):
-            df = self._bqml_model.generate_text(df_fail, options)
-
-            success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0
-            df_succ = df[success]
-            df_fail = df[~success]
-
-            if df_succ.empty:
-                if max_retries > 0:
-                    warnings.warn(
-                        "Can't make any progress, stop retrying.", RuntimeWarning
-                    )
-                break
-
-            df_result = (
-                bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
-            )
-
-            if df_fail.empty:
-                break
-
-        if not df_fail.empty:
-            warnings.warn(
-                f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
-                RuntimeWarning,
-            )
-
-        df_result = cast(
-            bpd.DataFrame,
-            bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
-        )
-
-        return df_result
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
 
     def score(
         self,
@@ -1144,7 +1125,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:
 
 
 @log_adapter.class_logger
-class Claude3TextGenerator(base.BaseEstimator):
+class Claude3TextGenerator(base.RetriableRemotePredictor):
     """Claude3 text generator LLM model.
 
     Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. Must have the
     Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
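
A minimal usage sketch of the shared retry contract (illustrative only; it assumes an existing bigframes session with a default BigQuery connection, and the model name is just one of the supported literals). The predict() methods of TextEmbeddingGenerator, GeminiTextGenerator, and Claude3TextGenerator in this patch all delegate to RetriableRemotePredictor._predict_and_retry, so max_retries behaves the same for each:

    import bigframes.pandas as bpd
    from bigframes.ml import llm

    df = bpd.DataFrame({"content": ["What is BigQuery?", "What is BQML?"]})

    model = llm.TextEmbeddingGenerator(model_name="text-embedding-004")
    # One initial request plus up to two retry rounds. Each round re-sends only
    # the rows whose ml_generate_embedding_status is non-empty; rows that still
    # fail after the last round are appended to the end of the result.
    embeddings = model.predict(df, max_retries=2)

    # The failed remainder, if any, can be filtered out and retried later.
    failed = embeddings[embeddings["ml_generate_embedding_status"] != ""]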
@@ -1273,6 +1254,14 @@ def _bqml_options(self) -> dict:
         }
         return options
 
+    @property
+    def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
+        return self._bqml_model.generate_text
+
+    @property
+    def _status_col(self) -> str:
+        return _ML_GENERATE_TEXT_STATUS
+
     def predict(
         self,
         X: utils.ArrayType,
@@ -1280,6 +1269,7 @@ def predict(
         max_output_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
+        max_retries: int = 0,
     ) -> bpd.DataFrame:
         """Predict the result from input DataFrame.
 
@@ -1307,6 +1297,10 @@ def predict(
                 Specify a lower value for less random responses and a higher value for more random responses.
                 Default 0.95. Possible values [0.0, 1.0].
 
+            max_retries (int, default 0):
+                Max number of retries if the prediction of any rows fails. Each retry needs to make progress (i.e. successfully predict some rows) for retrying to continue.
+                Each retry appends the newly succeeded rows. When max retries are reached, the remaining rows (the ones without successful predictions) are appended to the end of the result.
+
         Returns:
             bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
 
@@ -1324,6 +1318,11 @@ def predict(
         if top_p < 0.0 or top_p > 1.0:
             raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")
 
+        if max_retries < 0:
+            raise ValueError(
+                f"max_retries must be larger than or equal to 0, but is {max_retries}."
+            )
+
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
         if len(X.columns) == 1:
@@ -1338,15 +1337,7 @@ def predict(
             "flatten_json_output": True,
         }
 
-        df = self._bqml_model.generate_text(X, options)
-
-        if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
-            warnings.warn(
-                f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
-                RuntimeWarning,
-            )
-
-        return df
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
 
     def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
         """Save the model to BigQuery.
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index 304204cc7b..29f504443a 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -381,7 +381,35 @@ def __eq__(self, other):
         return self.equals(other)
 
 
-def test_gemini_text_generator_retry_success(session, bq_connection):
+@pytest.mark.parametrize(
+    (
+        "model_class",
+        "options",
+    ),
+    [
+        (
+            llm.GeminiTextGenerator,
+            {
+                "temperature": 0.9,
+                "max_output_tokens": 8192,
+                "top_k": 40,
+                "top_p": 1.0,
+                "flatten_json_output": True,
+                "ground_with_google_search": False,
+            },
+        ),
+        (
+            llm.Claude3TextGenerator,
+            {
+                "max_output_tokens": 128,
+                "top_k": 40,
+                "top_p": 0.95,
+                "flatten_json_output": True,
+            },
+        ),
+    ],
+)
+def test_text_generator_retry_success(session, bq_connection, model_class, options):
     # Requests.
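+    # The mocked responses make a shrinking subset of rows fail (non-empty
+    # status), so each retry round re-sends only the still-failing rows.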
     df0 = EqCmpAllDataFrame(
         {
             "prompt": [
                 "What is BigQuery?",
                 "What is BQML?",
                 "What is BigQuery DataFrame?",
             ]
         },
         index=[0, 1, 2],
         session=session,
     )
@@ -455,22 +483,12 @@ def test_gemini_text_generator_retry_success(session, bq_connection):
             session=session,
         ),
     ]
-    options = {
-        "temperature": 0.9,
-        "max_output_tokens": 8192,
-        "top_k": 40,
-        "top_p": 1.0,
-        "flatten_json_output": True,
-        "ground_with_google_search": False,
-    }
 
-    gemini_text_generator_model = llm.GeminiTextGenerator(
-        connection_name=bq_connection, session=session
-    )
-    gemini_text_generator_model._bqml_model = mock_bqml_model
+    text_generator_model = model_class(connection_name=bq_connection, session=session)
+    text_generator_model._bqml_model = mock_bqml_model
 
     # 3rd retry isn't triggered
-    result = gemini_text_generator_model.predict(df0, max_retries=3)
+    result = text_generator_model.predict(df0, max_retries=3)
 
     mock_bqml_model.generate_text.assert_has_calls(
         [
             mock.call(df0, options),
             mock.call(df1, options),
             mock.call(df2, options),
         ]
     )
     pd.testing.assert_frame_equal(
         result.to_pandas(),
         pd.DataFrame(
             {
                 "ml_generate_text_status": ["", "", ""],
                 "prompt": [
                     "What is BigQuery?",
                     "What is BigQuery DataFrame?",
                     "What is BQML?",
                 ],
             },
             index=[0, 2, 1],
         ),
         check_dtype=False,
         check_index_type=False,
     )
 
 
-def test_gemini_text_generator_retry_no_progress(session, bq_connection):
+@pytest.mark.parametrize(
+    (
+        "model_class",
+        "options",
+    ),
+    [
+        (
+            llm.GeminiTextGenerator,
+            {
+                "temperature": 0.9,
+                "max_output_tokens": 8192,
+                "top_k": 40,
+                "top_p": 1.0,
+                "flatten_json_output": True,
+                "ground_with_google_search": False,
+            },
+        ),
+        (
+            llm.Claude3TextGenerator,
+            {
+                "max_output_tokens": 128,
+                "top_k": 40,
+                "top_p": 0.95,
+                "flatten_json_output": True,
+            },
+        ),
+    ],
+)
+def test_text_generator_retry_no_progress(session, bq_connection, model_class, options):
     # Requests.
     df0 = EqCmpAllDataFrame(
         {
             "prompt": [
                 "What is BigQuery?",
                 "What is BQML?",
                 "What is BigQuery DataFrame?",
             ]
         },
         index=[0, 1, 2],
         session=session,
     )
@@ -550,24 +596,214 @@ def test_gemini_text_generator_retry_no_progress(session, bq_connection):
             session=session,
         ),
     ]
+
+    text_generator_model = model_class(connection_name=bq_connection, session=session)
+    text_generator_model._bqml_model = mock_bqml_model
+
+    # No progress, only conduct retry once
+    result = text_generator_model.predict(df0, max_retries=3)
+
+    mock_bqml_model.generate_text.assert_has_calls(
+        [
+            mock.call(df0, options),
+            mock.call(df1, options),
+        ]
+    )
+    pd.testing.assert_frame_equal(
+        result.to_pandas(),
+        pd.DataFrame(
+            {
+                "ml_generate_text_status": ["", "error", "error"],
+                "prompt": [
+                    "What is BigQuery?",
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[0, 1, 2],
+        ),
+        check_dtype=False,
+        check_index_type=False,
+    )
+
+
+def test_text_embedding_generator_retry_success(session, bq_connection):
+    # Requests.
+    df0 = EqCmpAllDataFrame(
+        {
+            "content": [
+                "What is BigQuery?",
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ]
+        },
+        index=[0, 1, 2],
+        session=session,
+    )
+    df1 = EqCmpAllDataFrame(
+        {
+            "ml_generate_embedding_status": ["error", "error"],
+            "content": [
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ],
+        },
+        index=[1, 2],
+        session=session,
+    )
+    df2 = EqCmpAllDataFrame(
+        {
+            "ml_generate_embedding_status": ["error"],
+            "content": [
+                "What is BQML?",
+            ],
+        },
+        index=[1],
+        session=session,
+    )
+
+    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
+    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
+
+    # Responses. Retry twice then all succeeded.
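+    # An empty status string marks a row as succeeded; a non-empty status
+    # marks it as failed and eligible for the next retry round.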
+    mock_bqml_model.generate_embedding.side_effect = [
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_embedding_status": ["", "error", "error"],
+                "content": [
+                    "What is BigQuery?",
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[0, 1, 2],
+            session=session,
+        ),
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_embedding_status": ["error", ""],
+                "content": [
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[1, 2],
+            session=session,
+        ),
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_embedding_status": [""],
+                "content": [
+                    "What is BQML?",
+                ],
+            },
+            index=[1],
+            session=session,
+        ),
+    ]
+    options = {
+        "flatten_json_output": True,
+    }
+
+    text_embedding_model = llm.TextEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    text_embedding_model._bqml_model = mock_bqml_model
+
+    # 3rd retry isn't triggered
+    result = text_embedding_model.predict(df0, max_retries=3)
+
+    mock_bqml_model.generate_embedding.assert_has_calls(
+        [
+            mock.call(df0, options),
+            mock.call(df1, options),
+            mock.call(df2, options),
+        ]
+    )
+    pd.testing.assert_frame_equal(
+        result.to_pandas(),
+        pd.DataFrame(
+            {
+                "ml_generate_embedding_status": ["", "", ""],
+                "content": [
+                    "What is BigQuery?",
+                    "What is BigQuery DataFrame?",
+                    "What is BQML?",
+                ],
+            },
+            index=[0, 2, 1],
+        ),
+        check_dtype=False,
+        check_index_type=False,
+    )
+
+
+def test_text_embedding_generator_retry_no_progress(session, bq_connection):
+    # Requests.
+    df0 = EqCmpAllDataFrame(
+        {
+            "content": [
+                "What is BigQuery?",
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ]
+        },
+        index=[0, 1, 2],
+        session=session,
+    )
+    df1 = EqCmpAllDataFrame(
+        {
+            "ml_generate_embedding_status": ["error", "error"],
+            "content": [
+                "What is BQML?",
+                "What is BigQuery DataFrame?",
+            ],
+        },
+        index=[1, 2],
+        session=session,
+    )
+
+    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
+    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
+    # Responses. Retry once, no progress, just stop.
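+    # The second response fails the same rows as the first, so the retry makes
+    # no progress and predict() stops early with a RuntimeWarning.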
+    mock_bqml_model.generate_embedding.side_effect = [
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_embedding_status": ["", "error", "error"],
+                "content": [
+                    "What is BigQuery?",
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[0, 1, 2],
+            session=session,
+        ),
+        EqCmpAllDataFrame(
+            {
+                "ml_generate_embedding_status": ["error", "error"],
+                "content": [
+                    "What is BQML?",
+                    "What is BigQuery DataFrame?",
+                ],
+            },
+            index=[1, 2],
+            session=session,
+        ),
+    ]
     options = {
-        "temperature": 0.9,
-        "max_output_tokens": 8192,
-        "top_k": 40,
-        "top_p": 1.0,
         "flatten_json_output": True,
-        "ground_with_google_search": False,
     }
-    gemini_text_generator_model = llm.GeminiTextGenerator(
+    text_embedding_model = llm.TextEmbeddingGenerator(
         connection_name=bq_connection, session=session
     )
-    gemini_text_generator_model._bqml_model = mock_bqml_model
+    text_embedding_model._bqml_model = mock_bqml_model
 
     # No progress, only conduct retry once
-    result = gemini_text_generator_model.predict(df0, max_retries=3)
+    result = text_embedding_model.predict(df0, max_retries=3)
 
-    mock_bqml_model.generate_text.assert_has_calls(
+    mock_bqml_model.generate_embedding.assert_has_calls(
         [
             mock.call(df0, options),
             mock.call(df1, options),
         ]
     )
     pd.testing.assert_frame_equal(
         result.to_pandas(),
         pd.DataFrame(
             {
-                "ml_generate_text_status": ["", "error", "error"],
-                "prompt": [
+                "ml_generate_embedding_status": ["", "error", "error"],
+                "content": [
                     "What is BigQuery?",
                     "What is BQML?",
                     "What is BigQuery DataFrame?",