8000 TST add some test for the new behaviour · scikit-learn/scikit-learn@9bb6ebb · GitHub
[go: up one dir, main page]

Skip to content

Commit 9bb6ebb

Browse files
committed
TST add some test for the new behaviour
1 parent fc45b4f commit 9bb6ebb

File tree

4 files changed

+72
-14
lines changed

4 files changed

+72
-14
lines changed

doc/whats_new/v0.24.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ Changelog
7676
redundant with the `dictionary` attribute and constructor parameter.
7777
:pr:`17679` by :user:`Xavier Dupré <sdpython>`.
7878

79+
:mod:`sklearn.dummy`
80+
....................
81+
82+
- |Enhancement| Add a parameter `interpolation` to
83+
:class:`dummy.DummyRegressor` to choose the type of interpolation with the
84+
strategies `median` and `quantile`. Beware that the interpolation will always
85+
be `'linear'` with and without `sample_weight` in the future.
86+
:pr:`xxx` by :user:`Guillaume Lemaitre <glemaitre>`.
87+
7988
:mod:`sklearn.ensemble`
8089
.......................
8190

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ addopts =
1212
--ignore examples
1313
--ignore maint_tools
1414
--doctest-modules
15-
--disable-pytest-warnings
15+
# --disable-pytest-warnings
1616
-rxXs
1717

1818
filterwarnings =

sklearn/dummy.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,15 @@ class DummyRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
430430
* `"higher"`: `j`;
431431
* `"nearest"`: `i` or `j`, whichever is nearest.
432432
433-
When
433+
By default, if `sample_weight` is `None`, `interpolation="linear"`,
434+
otherwise `interpolation="nearest"`.
434435
435436
.. versionadded:: 0.24
436437
438+
.. versionchanged:: 0.24
439+
`interpolation` will be `"linear"` whether the regressor is fitted
440+
with or without `sample_weight` from 0.26.
441+
437442
Attributes
438443
----------
439444
constant_ : array of shape (1, n_outputs)
@@ -501,13 +506,26 @@ def fit(self, X, y, sample_weight=None):
501506

502507
if sample_weight is not None:
503508
sample_weight = _check_sample_weight(sample_weight, X)
504-
interpolation = (
505-
"nearest" if self.interpolation is None else self.interpolation
506-
)
507-
else:
508-
interpolation = (
509-
"linear" if self.interpolation is None else self.interpolation
510-
)
509+
510+
# FIXME: change the default interpolation to "linear" in 0.26
511+
if self.strategy in ("median", "quantile"):
512+
if sample_weight is not None:
513+
if self.interpolation is None:
514+
warnings.warn(
515+
"From 0.26 and onward, interpolation will be 'linear' "
516+
"by default when fitting with some sample weights. You"
517+
" can force `interpolation='linear'` to get the new "
518+
"behaviour and silence this warning.",
519+
FutureWarning
520+
)
521+
interpolation = "nearest"
522+
else:
523+
interpolation = self.interpolation
524+
else:
525+
interpolation = (
526+
"linear" if self.interpolation is None
527+
else self.interpolation
528+
)
511529

512530
if self.strategy == "mean":
513531
self.constant_ = np.average(y, axis=0, weights=sample_weight)

sklearn/tests/test_dummy.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -665,15 +665,19 @@ def test_dummy_regressor_sample_weight(n_samples=10):
665665
est = DummyRegressor(strategy="mean").fit(X, y, sample_weight)
666666
assert est.constant_ == np.average(y, weights=sample_weight)
667667

668-
est = DummyRegressor(strategy="median").fit(X, y, sample_weight)
668+
interpolation = "linear"
669+
est = DummyRegressor(strategy="median", interpolation=interpolation)
670+
est.fit(X, y, sample_weight)
669671
assert est.constant_ == _weighted_percentile(
670-
y, sample_weight, 50., interpolation="nearest",
672+
y, sample_weight, 50., interpolation=interpolation,
671673
)
672674

673-
est = DummyRegressor(strategy="quantile", quantile=.95).fit(X, y,
674-
sample_weight)
675+
est = DummyRegressor(
676+
strategy="quantile", quantile=.95, interpolation=interpolation,
677+
)
678+
est.fit(X, y, sample_weight)
675679
assert est.constant_ == _weighted_percentile(
676-
y, sample_weight, 95., interpolation="nearest",
680+
y, sample_weight, 95., interpolation=interpolation,
677681
)
678682

679683

@@ -771,6 +775,7 @@ def test_n_features_in_(Dummy):
771775
assert d.n_features_in_ is None
772776

773777

778+
@pytest.mark.filterwarnings("ignore:From 0.26 and onward, interpolation will")
774779
@pytest.mark.parametrize(
775780
"strategy, quantile", [("median", 0.5), ("quantile", 0.9)]
776781
)
@@ -803,3 +808,29 @@ def test_dummy_regressor_default_legacy_behaviour(strategy, quantile):
803808
y, sample_weight, percentile=percentile, interpolation="nearest",
804809
)
805810
)
811+
812+
813+
@pytest.mark.parametrize(
    "strategy, quantile", [("median", 0.5), ("quantile", 0.9)]
)
@pytest.mark.parametrize(
    "interpolation, WarningType, expected_n_warnings",
    [(None, FutureWarning, 1), ("linear", None, 0)]
)
def test_dummy_regressort_future_warning_interpolation(
    strategy, quantile, interpolation, WarningType, expected_n_warnings,
):
    """Check when fitting with `sample_weight` raises the deprecation warning.

    With the `"median"` and `"quantile"` strategies, leaving
    `interpolation=None` while passing `sample_weight` must emit exactly one
    `FutureWarning`; setting `interpolation="linear"` explicitly must emit no
    warning at all.
    """
    # Local import: only this test needs the module, and the rest of the
    # test file's import block is not visible from here.
    import warnings

    rng = np.random.RandomState(seed=1)

    n_samples = 100
    X = [[0]] * n_samples
    y = rng.rand(n_samples)
    sample_weight = rng.rand(n_samples)

    regressor = DummyRegressor(
        strategy=strategy, quantile=quantile, interpolation=interpolation,
    )

    # `pytest.warns(None)` is deprecated (pytest 7) and rejected (pytest 8),
    # so record the warnings with the standard library instead. "always"
    # makes sure no filter hides or deduplicates what `fit` emits.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        regressor.fit(X, y, sample_weight=sample_weight)

    assert len(record) == expected_n_warnings
    if WarningType is not None:
        # Same guarantee `pytest.warns(WarningType)` gave: what was raised is
        # of the expected category.
        assert all(
            issubclass(warning.category, WarningType) for warning in record
        )

0 commit comments

Comments
 (0)
0