8000 Inject safer test value for MeanShift.bandwidth in estimator checks (… · glemaitre/scikit-learn@9dc1a5f · GitHub
[go: up one dir, main page]

Skip to content

Commit 9dc1a5f

Browse files
jjerphanglemaitrejeremiedbb
committed
Inject safer test value for MeanShift.bandwidth in estimator checks (scikit-learn#21501)
This modify the test configuration so that it makes sense for when a sole sample is provided for MeanShift. This test was passing previously for this configuration but was not supposed to. The new implementation strategy for kneighbo 8000 rs which uses PairwiseDistancesArgKmin (see scikit-learn#21462) is numerically stabler for this case, motivating this modication. Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent 3e678e8 commit 9dc1a5f

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,11 @@ def _set_checking_parameters(estimator):
632632
if "n_init" in params:
633633
# K-Means 69F2
634634
estimator.set_params(n_init=2)
635+
if name == "MeanShift":
636+
# In the case of check_fit2d_1sample, bandwidth is set to None and
637+
# is thus estimated. De facto it is 0.0 as a single sample is provided
638+
# and this makes the test fails. Hence we give it a placeholder value.
639+
estimator.set_params(bandwidth=1.0)
635640

636641
if name == "TruncatedSVD":
637642
# TruncatedSVD doesn't run with n_components = n_features

0 commit comments

Comments
 (0)
0