diff --git a/bigframes/ml/base.py b/bigframes/ml/base.py index a0800c19e6..2b25bc82f0 100644 --- a/bigframes/ml/base.py +++ b/bigframes/ml/base.py @@ -22,7 +22,7 @@ """ import abc -from typing import Callable, cast, Mapping, Optional, TypeVar +from typing import Callable, cast, Mapping, Optional, TypeVar, Union import warnings import bigframes_vendored.sklearn.base @@ -259,38 +259,29 @@ def _predict_and_retry( ) -> bpd.DataFrame: assert self._bqml_model is not None - df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder - df_fail = X - for _ in range(max_retries + 1): + df_result: Union[bpd.DataFrame, None] = None # placeholder + df_succ = df_fail = X + for i in range(max_retries + 1): + if i > 0 and df_fail.empty: + break + if i > 0 and df_succ.empty: + msg = bfe.format_message("Can't make any progress, stop retrying.") + warnings.warn(msg, category=RuntimeWarning) + break + df = self._predict_func(df_fail, options) success = df[self._status_col].str.len() == 0 df_succ = df[success] df_fail = df[~success] - if df_succ.empty: - if max_retries > 0: - msg = bfe.format_message("Can't make any progress, stop retrying.") - warnings.warn(msg, category=RuntimeWarning) - break - df_result = ( - bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ - ) - - if df_fail.empty: - break - - if not df_fail.empty: - msg = bfe.format_message( - f"Some predictions failed. Check column {self._status_col} for detailed " - "status. You may want to filter the failed rows and retry." + bpd.concat([df_result, df_succ]) if df_result is not None else df_succ ) - warnings.warn(msg, category=RuntimeWarning) df_result = cast( bpd.DataFrame, - bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail, + bpd.concat([df_result, df_fail]) if df_result is not None else df_fail, ) return df_result