8000 API Deprecates values in partial_dependence in favor of pdp_values (#… · scikit-learn/scikit-learn@fabe160 · GitHub
[go: up one dir, main page]

Skip to content

Commit fabe160

Browse files
thomasjpfan, glemaitre, jeremiedbb
authored
API Deprecates values in partial_dependence in favor of pdp_values (#21809)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 96625cf commit fabe160

File tree

7 files changed

+133
-27
lines changed

7 files changed

+133
-27
lines changed

doc/whats_new/v1.3.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ Changelog
168168
- |Enhancement| Added the parameter `fill_value` to :class:`impute.IterativeImputer`.
169169
:pr:`25232` by :user:`Thijs van Weezel <ValueInvestorThijs>`.
170170

171+
:mod:`sklearn.inspection`
172+
.........................
173+
174+
- |API| :func:`inspection.partial_dependence` returns a :class:`utils.Bunch` with
175+
a new key: `pdp_values`. The `values` key is deprecated in favor of `pdp_values`
176+
and the `values` key will be removed in 1.5.
177+
:pr:`21809` by `Thomas Fan`_.
178+
171179
:mod:`sklearn.linear_model`
172180
...........................
173181

sklearn/inspection/_partial_dependence.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -365,16 +365,26 @@ def partial_dependence(
365365
Only available when ``kind='both'``.
366366
367367
values : seq of 1d ndarrays
368+
The values with which the grid has been created.
369+
370+
.. deprecated:: 1.3
371+
The key `values` has been deprecated in 1.3 and will be removed
372+
in 1.5 in favor of `pdp_values`. See `pdp_values` for details
373+
about the `values` attribute.
374+
375+
pdp_values : seq of 1d ndarrays
368376
The values with which the grid has been created. The generated
369-
grid is a cartesian product of the arrays in ``values``.
370-
``len(values) == len(features)``. The size of each array
371-
``values[j]`` is either ``grid_resolution``, or the number of
377+
grid is a cartesian product of the arrays in ``pdp_values``.
378+
``len(pdp_values) == len(features)``. The size of each array
379+
``pdp_values[j]`` is either ``grid_resolution``, or the number of
372380
unique values in ``X[:, j]``, whichever is smaller.
373381
382+
.. versionadded:: 1.3
383+
374384
``n_outputs`` corresponds to the number of classes in a multi-class
375385
setting, or to the number of tasks for multi-output regression.
376386
For classical regression and binary classification ``n_outputs==1``.
377-
``n_values_feature_j`` corresponds to the size ``values[j]``.
387+
``n_values_feature_j`` corresponds to the size ``pdp_values[j]``.
378388
379389
See Also
380390
--------
@@ -547,14 +557,22 @@ def partial_dependence(
547557
averaged_predictions = averaged_predictions.reshape(
548558
-1, *[val.shape[0] for val in values]
549559
)
560+
pdp_results = Bunch()
561+
562+
msg = (
563+
"Key: 'values', is deprecated in 1.3 and will be removed in 1.5. "
564+
"Please use 'pdp_values' instead."
565+
)
566+
pdp_results._set_deprecated(
567+
values, new_key="pdp_values", deprecated_key="values", warning_message=msg
568+
)
550569

551570
if kind == "average":
552-
return Bunch(average=averaged_predictions, values=values)
571+
pdp_results["average"] = averaged_predictions
553572
elif kind == "individual":
554-
return Bunch(individual=predictions, values=values)
573+
pdp_results["individual"] = predictions
555574
else: # kind='both'
556-
return Bunch(
557-
average=averaged_predictions,
558-
individual=predictions,
559-
values=values,
560-
)
575+
pdp_results["average"] = averaged_predictions
576+
pdp_results["individual"] = predictions
577+
578+
return pdp_results

sklearn/inspection/_plot/partial_dependence.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,7 +1256,7 @@ def plot(
12561256
else:
12571257
pd_results_ = []
12581258
for kind_plot, pd_result in zip(kind, self.pd_results):
1259-
current_results = {"values": pd_result["values"]}
1259+
current_results = {"pdp_values": pd_result["pdp_values"]}
12601260

12611261
if kind_plot in ("individual", "both"):
12621262
preds = pd_result.individual
@@ -1274,7 +1274,7 @@ def plot(
12741274
# get global min and max average predictions of PD grouped by plot type
12751275
pdp_lim = {}
12761276
for kind_plot, pdp in zip(kind, pd_results_):
1277-
values = pdp["values"]
1277+
values = pdp["pdp_values"]
12781278
preds = pdp.average if kind_plot == "average" else pdp.individual
12791279
min_pd = preds[self.target_idx].min()
12801280
max_pd = preds[self.target_idx].max()
@@ -1402,7 +1402,7 @@ def plot(
14021402
):
14031403
avg_preds = None
14041404
preds = None
1405-
feature_values = pd_result["values"]
1405+
feature_values = pd_result["pdp_values"]
14061406
if kind_plot == "individual":
14071407
preds = pd_result.individual
14081408
elif kind_plot == "average":

sklearn/inspection/_plot/tests/test_plot_partial_dependence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_plot_partial_dependence(grid_resolution, pyplot, clf_diabetes, diabetes
103103
target_idx = disp.target_idx
104104

105105
line_data = line.get_data()
106-
assert_allclose(line_data[0], avg_preds["values"][0])
106+
assert_allclose(line_data[0], avg_preds["pdp_values"][0])
107107
assert_allclose(line_data[1], avg_preds.average[target_idx].ravel())
108108

109109
# two feature position
@@ -243,7 +243,7 @@ def test_plot_partial_dependence_str_features(
243243
assert line.get_alpha() == 0.8
244244

245245
line_data = line.get_data()
246-
assert_allclose(line_data[0], avg_preds["values"][0])
246+
assert_allclose(line_data[0], avg_preds["pdp_values"][0])
247247
assert_allclose(line_data[1], avg_preds.average[target_idx].ravel())
248248

249249
# contour
@@ -279,7 +279,7 @@ def test_plot_partial_dependence_custom_axes(pyplot, clf_diabetes, diabetes):
279279
target_idx = disp.target_idx
280280

281281
line_data = line.get_data()
282-
assert_allclose(line_data[0], avg_preds["values"][0])
282+
assert_allclose(line_data[0], avg_preds["pdp_values"][0])
283283
assert_allclose(line_data[1], avg_preds.average[target_idx].ravel())
284284

285285
# contour
@@ -466,7 +466,7 @@ def test_plot_partial_dependence_multiclass(pyplot):
466466
disp_target_0.pd_results, disp_symbol.pd_results
467467
):
468468
assert_allclose(int_result.average, symbol_result.average)
469-
assert_allclose(int_result["values"], symbol_result["values"])
469+
assert_allclose(int_result["pdp_values"], symbol_result["pdp_values"])
470470

471471
# check that the pd plots are different for another target
472472
disp_target_1 = PartialDependenceDisplay.from_estimator(

sklearn/inspection/tests/test_partial_dependence.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Testing for the partial dependence module.
33
"""
4+
import warnings
45

56
import numpy as np
67
import pytest
@@ -108,7 +109,7 @@ def test_output_shape(Estimator, method, data, grid_resolution, features, kind):
108109
kind=kind,
109110
grid_resolution=grid_resolution,
110111
)
111-
pdp, axes = result, result["values"]
112+
pdp, axes = result, result["pdp_values"]
112113

113114
expected_pdp_shape = (n_targets, *[grid_resolution for _ in range(len(features))])
114115
expected_ice_shape = (
@@ -434,7 +435,7 @@ def test_partial_dependence_easy_target(est, power):
434435
est, features=[target_variable], X=X, grid_resolution=1000, kind="average"
435436
)
436437

437-
new_X = pdp["values"][0].reshape(-1, 1)
438+
new_X = pdp["pdp_values"][0].reshape(-1, 1)
438439
new_y = pdp["average"][0]
439440
# add polynomial features if needed
440441
new_X = PolynomialFeatures(degree=power).fit_transform(new_X)
@@ -654,7 +655,7 @@ def test_partial_dependence_sample_weight():
654655

655656
pdp = partial_dependence(clf, X, features=[1], kind="average")
656657

657-
assert np.corrcoef(pdp["average"], pdp["values"])[0, 1] > 0.99
658+
assert np.corrcoef(pdp["average"], pdp["pdp_values"])[0, 1] > 0.99
658659

659660

660661
def test_hist_gbdt_sw_not_supported():
@@ -692,8 +693,8 @@ def test_partial_dependence_pipeline():
692693
)
693694
assert_allclose(pdp_pipe["average"], pdp_clf["average"])
694695
assert_allclose(
695-
pdp_pipe["values"][0],
696-
pdp_clf["values"][0] * scaler.scale_[features] + scaler.mean_[features],
696+
pdp_pipe["pdp_values"][0],
697+
pdp_clf["pdp_values"][0] * scaler.scale_[features] + scaler.mean_[features],
697698
)
698699

699700

@@ -761,11 +762,11 @@ def test_partial_dependence_dataframe(estimator, preprocessor, features):
761762
if preprocessor is not None:
762763
scaler = preprocessor.named_transformers_["standardscaler"]
763764
assert_allclose(
764-
pdp_pipe["values"][1],
765-
pdp_clf["values"][1] * scaler.scale_[1] + scaler.mean_[1],
765+
pdp_pipe["pdp_values"][1],
766+
pdp_clf["pdp_values"][1] * scaler.scale_[1] + scaler.mean_[1],
766767
)
767768
else:
768-
assert_allclose(pdp_pipe["values"][1], pdp_clf["values"][1])
769+
assert_allclose(pdp_pipe["pdp_values"][1], pdp_clf["pdp_values"][1])
769770

770771

771772
@pytest.mark.parametrize(
@@ -796,7 +797,7 @@ def test_partial_dependence_feature_type(features, expected_pd_shape):
796797
pipe, df, features=features, grid_resolution=10, kind="average"
797798
)
798799
assert pdp_pipe["average"].shape == expected_pd_shape
799-
assert len(pdp_pipe["values"]) == len(pdp_pipe["average"].shape) - 1
800+
assert len(pdp_pipe["pdp_values"]) == len(pdp_pipe["average"].shape) - 1
800801

801802

802803
@pytest.mark.parametrize(
@@ -836,3 +837,31 @@ def test_kind_average_and_average_of_individual(Estimator, data):
836837
pdp_ind = partial_dependence(est, X=X, features=[1, 2], kind="individual")
837838
avg_ind = np.mean(pdp_ind["individual"], axis=1)
838839
assert_allclose(avg_ind, pdp_avg["average"])
840+
841+
842+
# TODO(1.5): Remove when the deprecated bunch key "values" is removed in 1.5
843+
def test_partial_dependence_bunch_values_deprecated():
844+
"""Test that deprecation warning is raised when values is accessed."""
845+
846+
est = LogisticRegression()
847+
(X, y), _ = binary_classification_data
848+
est.fit(X, y)
849+
850+
pdp_avg = partial_dependence(est, X=X, features=[1, 2], kind="average")
851+
852+
msg = (
853+
"Key: 'values', is deprecated in 1.3 and will be "
854+
"removed in 1.5. Please use 'pdp_values' instead"
855+
)
856+
857+
with warnings.catch_warnings():
858+
# Does not raise warnings with "pdp_values"
859+
warnings.simplefilter("error", FutureWarning)
860+
pdp_values = pdp_avg["pdp_values"]
861+
862+
with pytest.warns(FutureWarning, match=msg):
863+
# Warns for "values"
864+
values = pdp_avg["values"]
865+
866+
# "values" and "pdp_values" are the same object
867+
assert values is pdp_values

sklearn/utils/_bunch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
3+
14
class Bunch(dict):
25
"""Container object exposing keys as attributes.
36
@@ -24,6 +27,22 @@ class Bunch(dict):
2427
def __init__(self, **kwargs):
2528
super().__init__(kwargs)
2629

30+
# Map from deprecated key to warning message
31+
self.__dict__["_deprecated_key_to_warnings"] = {}
32+
33+
def __getitem__(self, key):
34+
if key in self.__dict__.get("_deprecated_key_to_warnings", {}):
35+
warnings.warn(
36+
self._deprecated_key_to_warnings[key],
37+
FutureWarning,
38+
)
39+
return super().__getitem__(key)
40+
41+
def _set_deprecated(self, value, *, new_key, deprecated_key, warning_message):
42+
"""Set key in dictionary to be deprecated with its warning message."""
43+
self.__dict__["_deprecated_key_to_warnings"][deprecated_key] = warning_message
44+
self[new_key] = self[deprecated_key] = value
45+
2746
def __setattr__(self, key, value):
2847
self[key] = value
2948

sklearn/utils/tests/test_bunch.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import warnings
2+
3+
import numpy as np
4+
import pytest
5+
6+
from sklearn.utils import Bunch
7+
8+
9+
def test_bunch_attribute_deprecation():
10+
"""Check that bunch raises a deprecation warning on item access (`__getitem__`)."""
11+
bunch = Bunch()
12+
values = np.asarray([1, 2, 3])
13+
msg = (
14+
"Key: 'values', is deprecated in 1.3 and will be "
15+
"removed in 1.5. Please use 'pdp_values' instead"
16+
)
17+
bunch._set_deprecated(
18+
values, new_key="pdp_values", deprecated_key="values", warning_message=msg
19+
)
20+
21+
with warnings.catch_warnings():
22+
# Does not warn for "pdp_values"
23+
warnings.simplefilter("error")
24+
v = bunch["pdp_values"]
25+
26+
assert v is values
27+
28+
with pytest.warns(FutureWarning, match=msg):
29+
# Warns for "values"
30+
v = bunch["values"]
31+
32+
assert v is values

0 commit comments

Comments
 (0)
0