From 53c51f9e87815f0d380b31b3b312e3e7271ed7fd Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Wed, 12 Mar 2025 22:56:07 +0000 Subject: [PATCH 1/2] performance: eliminate count queries in llm retry --- bigframes/ml/base.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index a0800c19e6..796679accb 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -259,38 +259,29 @@ def _predict_and_retry( ) -> bpd.DataFrame: assert self._bqml_model is not None - df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder + df_result = X.iloc[:0] # placeholder df_fail = X - for _ in range(max_retries + 1): + for i in range(max_retries + 1): + if i > 0 and df_fail.empty: + break + df = self._predict_func(df_fail, options) success = df[self._status_col].str.len() == 0 df_succ = df[success] df_fail = df[~success] - if df_succ.empty: - if max_retries > 0: + if max_retries > 0: + if df_succ.empty: msg = bfe.format_message("Can't make any progress, stop retrying.") warnings.warn(msg, category=RuntimeWarning) break - df_result = ( - bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ - ) - - if df_fail.empty: - break - - if not df_fail.empty: - msg = bfe.format_message( - f"Some predictions failed. Check column {self._status_col} for detailed " - "status. You may want to filter the failed rows and retry." - ) - warnings.warn(msg, category=RuntimeWarning) + df_result = bpd.concat([df_result, df_succ]) df_result = cast( bpd.DataFrame, - bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail, + bpd.concat([df_result, df_fail]), ) return df_result From 262dff80915aaae6d9f0c8ae710ae1e36a99286a Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Thu, 13 Mar 2025 17:54:15 +0000 Subject: [PATCH 2/2] fix tests --- bigframes/ml/base.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index 796679accb..2b25bc82f0 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -22,7 +22,7 @@ """ import abc -from typing import Callable, cast, Mapping, Optional, TypeVar +from typing import Callable, cast, Mapping, Optional, TypeVar, Union import warnings import bigframes_vendored.sklearn.base @@ -259,11 +259,15 @@ def _predict_and_retry( ) -> bpd.DataFrame: assert self._bqml_model is not None - df_result = X.iloc[:0] # placeholder - df_fail = X + df_result: Union[bpd.DataFrame, None] = None # placeholder + df_succ = df_fail = X for i in range(max_retries + 1): if i > 0 and df_fail.empty: break + if i > 0 and df_succ.empty: + msg = bfe.format_message("Can't make any progress, stop retrying.") + warnings.warn(msg, category=RuntimeWarning) + break df = self._predict_func(df_fail, options) @@ -271,17 +275,13 @@ def _predict_and_retry( df_succ = df[success] df_fail = df[~success] - if max_retries > 0: - if df_succ.empty: - msg = bfe.format_message("Can't make any progress, stop retrying.") - warnings.warn(msg, category=RuntimeWarning) - break - - df_result = bpd.concat([df_result, df_succ]) + df_result = ( + bpd.concat([df_result, df_succ]) if df_result is not None else df_succ + ) df_result = cast( bpd.DataFrame, - bpd.concat([df_result, df_fail]), + bpd.concat([df_result, df_fail]) if df_result is not None else df_fail, ) return df_result