diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index 2f3b532a74..9617b5d7a5 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -123,10 +123,15 @@ def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
             self._model_manipulation_sql_generator.ml_predict,
         )
 
-    def explain_predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
+    def explain_predict(
+        self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
+    ) -> bpd.DataFrame:
         return self._apply_ml_tvf(
             input_data,
-            self._model_manipulation_sql_generator.ml_explain_predict,
+            lambda source_sql: self._model_manipulation_sql_generator.ml_explain_predict(
+                source_sql=source_sql,
+                struct_options=options,
+            ),
         )
 
     def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py
index eac0fd1fca..722b72f806 100644
--- a/bigframes/ml/linear_model.py
+++ b/bigframes/ml/linear_model.py
@@ -155,14 +155,15 @@ def _fit(
     def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
         if not self._bqml_model:
             raise RuntimeError("A model must be fitted before predict")
-
-        (X,) = utils.batch_convert_to_dataframe(X)
+        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
         return self._bqml_model.predict(X)
 
     def predict_explain(
         self,
         X: utils.ArrayType,
+        *,
+        top_k_features: int = 5,
     ) -> bpd.DataFrame:
         """
         Explain predictions for a linear regression model.
@@ -175,18 +176,32 @@ def predict_explain(
             X (bigframes.dataframe.DataFrame or bigframes.series.Series or
                 pandas.core.frame.DataFrame or pandas.core.series.Series):
                 Series or a DataFrame to explain its predictions.
+            top_k_features (int, default 5):
+                An INT64 value that specifies how many top feature attribution
+                pairs are generated for each row of input data. The features are
+                ranked by the absolute values of their attributions.
+
+                By default, top_k_features is set to 5. If its value is greater
+                than the number of features in the training data, the
+                attributions of all features are returned.
 
         Returns:
             bigframes.pandas.DataFrame:
                 The predicted DataFrames with explanation columns.
         """
-        # TODO(b/377366612): Add support for `top_k_features` parameter
+        if top_k_features < 1:
+            raise ValueError(
+                f"top_k_features must be at least 1, but is {top_k_features}."
+            )
+
         if not self._bqml_model:
             raise RuntimeError("A model must be fitted before predict")
 
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
-        return self._bqml_model.explain_predict(X)
+        return self._bqml_model.explain_predict(
+            X, options={"top_k_features": top_k_features}
+        )
 
     def score(
         self,
@@ -356,6 +371,8 @@ def predict(
     def predict_explain(
         self,
         X: utils.ArrayType,
+        *,
+        top_k_features: int = 5,
     ) -> bpd.DataFrame:
         """
         Explain predictions for a logistic regression model.
@@ -368,18 +385,32 @@ def predict_explain(
             X (bigframes.dataframe.DataFrame or bigframes.series.Series or
                 pandas.core.frame.DataFrame or pandas.core.series.Series):
                 Series or a DataFrame to explain its predictions.
+            top_k_features (int, default 5):
+                An INT64 value that specifies how many top feature attribution
+                pairs are generated for each row of input data. The features are
+                ranked by the absolute values of their attributions.
+
+                By default, top_k_features is set to 5. If its value is greater
+                than the number of features in the training data, the
+                attributions of all features are returned.
 
         Returns:
             bigframes.pandas.DataFrame:
                 The predicted DataFrames with explanation columns.
         """
-        # TODO(b/377366612): Add support for `top_k_features` parameter
+        if top_k_features < 1:
+            raise ValueError(
+                f"top_k_features must be at least 1, but is {top_k_features}."
+            )
+
         if not self._bqml_model:
             raise RuntimeError("A model must be fitted before predict")
 
         (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
 
-        return self._bqml_model.explain_predict(X)
+        return self._bqml_model.explain_predict(
+            X, options={"top_k_features": top_k_features}
+        )
 
     def score(
         self,
diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py
index 93b8a3a051..b662d4c22c 100644
--- a/bigframes/ml/sql.py
+++ b/bigframes/ml/sql.py
@@ -304,10 +304,13 @@ def ml_predict(self, source_sql: str) -> str:
         return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
   ({source_sql}))"""
 
-    def ml_explain_predict(self, source_sql: str) -> str:
+    def ml_explain_predict(
+        self, source_sql: str, struct_options: Mapping[str, Union[int, float]]
+    ) -> str:
         """Encode ML.EXPLAIN_PREDICT for BQML"""
+        struct_options_sql = self.struct_options(**struct_options)
         return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
-  ({source_sql}))"""
+  ({source_sql}), {struct_options_sql})"""
 
     def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
         """Encode ML.FORECAST for BQML"""
diff --git a/tests/system/small/ml/test_core.py b/tests/system/small/ml/test_core.py
index b9748f24d3..2a2e68b230 100644
--- a/tests/system/small/ml/test_core.py
+++ b/tests/system/small/ml/test_core.py
@@ -263,8 +263,9 @@ def test_model_predict(penguins_bqml_linear_model: core.BqmlModel, new_penguins_
 def test_model_predict_explain(
     penguins_bqml_linear_model: core.BqmlModel, new_penguins_df
 ):
+    options = {"top_k_features": 3}
     predictions = penguins_bqml_linear_model.explain_predict(
-        new_penguins_df
+        new_penguins_df, options
     ).to_pandas()
     expected = pd.DataFrame(
         {
@@ -317,6 +318,7 @@ def test_model_predict_explain_with_unnamed_index(
     # need to persist through the call to ML.PREDICT
     new_penguins_df = new_penguins_df.reset_index()
 
+    options = {"top_k_features": 3}
     # remove the middle tag number to ensure we're really keeping the unnamed index
     new_penguins_df = typing.cast(
         bigframes.dataframe.DataFrame,
@@ -324,7 +326,7 @@ def test_model_predict_explain_with_unnamed_index(
     )
 
     predictions = penguins_bqml_linear_model.explain_predict(
-        new_penguins_df
+        new_penguins_df, options
     ).to_pandas()
 
     expected = pd.DataFrame(
diff --git a/tests/system/small/ml/test_linear_model.py b/tests/system/small/ml/test_linear_model.py
index 3be1147c1e..da9fc8e14f 100644
--- a/tests/system/small/ml/test_linear_model.py
+++ b/tests/system/small/ml/test_linear_model.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
+
 import google.api_core.exceptions
 import pandas
 import pytest
@@ -132,6 +134,20 @@ def test_linear_reg_model_predict_explain(penguins_linear_model, new_penguins_df
     )
 
 
+def test_linear_model_predict_explain_top_k_features(
+    penguins_linear_model: linear_model.LinearRegression, new_penguins_df
+):
+    top_k_features = 0
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(f"top_k_features must be at least 1, but is {top_k_features}."),
+    ):
+        penguins_linear_model.predict_explain(
+            new_penguins_df, top_k_features=top_k_features
+        ).to_pandas()
+
+
 def test_linear_reg_model_predict_params(
     penguins_linear_model: linear_model.LinearRegression, new_penguins_df
 ):
@@ -307,6 +323,20 @@ def test_logistic_model_predict(penguins_logistic_model, new_penguins_df):
     )
 
 
+def test_logistic_model_predict_explain_top_k_features(
+    penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
+):
+    top_k_features = 0
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(f"top_k_features must be at least 1, but is {top_k_features}."),
+    ):
+        penguins_logistic_model.predict_explain(
+            new_penguins_df, top_k_features=top_k_features
+        ).to_pandas()
+
+
 def test_logistic_model_predict_params(
     penguins_logistic_model: linear_model.LogisticRegression, new_penguins_df
 ):
diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py
index 9d18649efe..5a7220fc38 100644
--- a/tests/unit/ml/test_sql.py
+++ b/tests/unit/ml/test_sql.py
@@ -342,18 +342,6 @@ def test_ml_predict_correct(
     )
 
 
-def test_ml_explain_predict_correct(
-    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
-    mock_df: bpd.DataFrame,
-):
-    sql = model_manipulation_sql_generator.ml_explain_predict(source_sql=mock_df.sql)
-    assert (
-        sql
-        == """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
-  (input_X_y_sql))"""
-    )
-
-
 def test_ml_llm_evaluate_correct(
     model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
     mock_df: bpd.DataFrame,
@@ -462,6 +450,23 @@ def test_ml_generate_embedding_correct(
     )
 
 
+def test_ml_explain_predict_correct(
+    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
+    mock_df: bpd.DataFrame,
+):
+    sql = model_manipulation_sql_generator.ml_explain_predict(
+        source_sql=mock_df.sql,
+        struct_options={"option_key1": 1, "option_key2": 2.25},
+    )
+    assert (
+        sql
+        == """SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
+  (input_X_y_sql), STRUCT(
+  1 AS `option_key1`,
+  2.25 AS `option_key2`))"""
+    )
+
+
 def test_ml_detect_anomalies_correct_sql(
     model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
     mock_df: bpd.DataFrame,