ENH Support sample weights in PartialDependenceDisplay.from_estimator… · scikit-learn/scikit-learn@25e179a · GitHub

Commit 25e179a

vitaliset authored and jeremiedbb committed
ENH Support sample weights in PartialDependenceDisplay.from_estimator (#26644)
1 parent 2a45ddd commit 25e179a

3 files changed: +55 -10 lines changed

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 4 deletions
@@ -421,10 +421,11 @@ Changelog
 .........................

 - |Enhancement| Added support for `sample_weight` in
-  :func:`inspection.partial_dependence`. This allows for weighted averaging when
-  aggregating for each value of the grid we are making the inspection on. The
-  option is only available when `method` is set to `brute`. :pr:`25209`
-  by :user:`Carlo Lemos <vitaliset>`.
+  :func:`inspection.partial_dependence` and
+  :meth:`inspection.PartialDependenceDisplay.from_estimator`. This allows for
+  weighted averaging when aggregating for each value of the grid we are making the
+  inspection on. The option is only available when `method` is set to `brute`.
+  :pr:`25209` and :pr:`26644` by :user:`Carlo Lemos <vitaliset>`.

 - |API| :func:`inspection.partial_dependence` returns a :class:`utils.Bunch` with
   new key: `grid_values`. The `values` key is deprecated in favor of `grid_values`
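
To make the changelog entry concrete, here is a minimal pure-Python sketch of what "weighted averaging when aggregating for each value of the grid" could look like with the brute-force method. It is not the scikit-learn implementation: `weighted_partial_dependence` and `toy_predict` are hypothetical names introduced only for illustration, and the toy model and data are made up.

```python
def weighted_partial_dependence(predict, X, feature, grid, sample_weight=None):
    """Brute-force partial dependence of one feature (illustrative sketch).

    For each grid value, the feature is overwritten in every row, the model is
    evaluated, and the predictions are averaged using the sample weights.
    """
    n_samples = len(X)
    # Assumption: None means every sample counts equally.
    weights = [1.0] * n_samples if sample_weight is None else list(sample_weight)
    total_weight = sum(weights)

    averaged = []
    for value in grid:
        predictions = []
        for row in X:
            modified = list(row)        # copy so the original data is untouched
            modified[feature] = value   # force the inspected feature to the grid value
            predictions.append(predict(modified))
        weighted_sum = sum(w * p for w, p in zip(weights, predictions))
        averaged.append(weighted_sum / total_weight)
    return averaged


def toy_predict(row):
    # Hypothetical model with an interaction term, so the weights matter.
    return row[0] * row[1]


X = [[0.0, 1.0], [0.0, 3.0]]
grid = [1.0, 2.0]

unweighted = weighted_partial_dependence(toy_predict, X, 0, grid)
constant = weighted_partial_dependence(toy_predict, X, 0, grid, sample_weight=[2.0, 2.0])
weighted = weighted_partial_dependence(toy_predict, X, 0, grid, sample_weight=[3.0, 1.0])

print(unweighted)  # [2.0, 4.0]
print(constant)    # [2.0, 4.0]
print(weighted)    # [1.5, 3.0]
```

Constant weights reproduce the unweighted curve, while up-weighting the first sample pulls each grid point toward that sample's prediction.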

sklearn/inspection/_plot/partial_dependence.py

Lines changed: 19 additions & 6 deletions
@@ -86,8 +86,9 @@ class PartialDependenceDisplay:

         .. note::
            The fast ``method='recursion'`` option is only available for
-           ``kind='average'``. Plotting individual dependencies requires using
-           the slower ``method='brute'`` option.
+           `kind='average'` and `sample_weights=None`. Computing individual
+           dependencies and doing weighted averages requires using the slower
+           `method='brute'`.

         .. versionadded:: 0.24
            Add `kind` parameter with `'average'`, `'individual'`, and `'both'`
@@ -247,6 +248,7 @@ def from_estimator(
         X,
         features,
         *,
+        sample_weight=None,
         categorical_features=None,
         feature_names=None,
         target=None,
@@ -337,6 +339,14 @@ def from_estimator(
             with `kind='average'`). Each tuple must be of size 2.
             If any entry is a string, then it must be in ``feature_names``.

+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights are used to calculate weighted means when averaging the
+            model output. If `None`, then samples are equally weighted. If
+            `sample_weight` is not `None`, then `method` will be set to `'brute'`.
+            Note that `sample_weight` is ignored for `kind='individual'`.
+
+            .. versionadded:: 1.3
+
         categorical_features : array-like of shape (n_features,) or shape \
                 (n_categorical_features,), dtype={bool, int, str}, default=None
             Indicates the categorical features.
@@ -409,7 +419,8 @@ def from_estimator(
               computationally intensive.

             - `'auto'`: the `'recursion'` is used for estimators that support it,
-              and `'brute'` is used otherwise.
+              and `'brute'` is used otherwise. If `sample_weight` is not `None`,
+              then `'brute'` is used regardless of the estimator.

             Please see :ref:`this note <pdp_method_differences>` for
             differences between the `'brute'` and `'recursion'` method.
@@ -464,9 +475,10 @@ def from_estimator(
             - ``kind='average'`` results in the traditional PD plot;
             - ``kind='individual'`` results in the ICE plot.

-            Note that the fast ``method='recursion'`` option is only available for
-            ``kind='average'``. Plotting individual dependencies requires using the
-            slower ``method='brute'`` option.
+            Note that the fast `method='recursion'` option is only available for
+            `kind='average'` and `sample_weights=None`. Computing individual
+            dependencies and doing weighted averages requires using the slower
+            `method='brute'`.

         centered : bool, default=False
             If `True`, the ICE and PD lines will start at the origin of the
@@ -693,6 +705,7 @@ def from_estimator(
                 estimator,
                 X,
                 fxs,
+                sample_weight=sample_weight,
                 feature_names=feature_names,
                 categorical_features=categorical_features,
                 response_method=response_method,

sklearn/inspection/_plot/tests/test_plot_partial_dependence.py

Lines changed: 31 additions & 0 deletions
@@ -1086,3 +1086,34 @@ def test_partial_dependence_display_kind_centered_interaction(
     )

     assert all([ln._y[0] == 0.0 for ln in disp.lines_.ravel() if ln is not None])
+
+
+def test_partial_dependence_display_with_constant_sample_weight(
+    pyplot,
+    clf_diabetes,
+    diabetes,
+):
+    """Check that the utilization of a constant sample weight maintains the
+    standard behavior.
+    """
+    disp = PartialDependenceDisplay.from_estimator(
+        clf_diabetes,
+        diabetes.data,
+        [0, 1],
+        kind="average",
+        method="brute",
+    )
+
+    sample_weight = np.ones_like(diabetes.target)
+    disp_sw = PartialDependenceDisplay.from_estimator(
+        clf_diabetes,
+        diabetes.data,
+        [0, 1],
+        sample_weight=sample_weight,
+        kind="average",
+        method="brute",
+    )
+
+    assert np.array_equal(
+        disp.pd_results[0]["average"], disp_sw.pd_results[0]["average"]
+    )
