FIX validate parameter if `fit` in `KernelDensity` estimator (#21430) · scikit-learn/scikit-learn@aa675f9 · GitHub
[go: up one dir, main page]

Skip to content

Commit aa675f9

Browse files
authored
FIX validate parameter if fit in KernelDensity estimator (#21430)
Co-Authored-By: Lucy Jiménez <lucy.jimenez.chem@gmail.com>
1 parent 8d26900 commit aa675f9

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

doc/whats_new/v1.1.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ Changelog
9595
message when running in a jupyter notebook that is not trusted. :pr:`21316`
9696
by `Thomas Fan`_.
9797

98+
:mod:`sklearn.neighbors`
99+
........................
100+
101+
- |Fix| :class:`neighbors.KernelDensity` now validates input parameters in `fit`
102+
instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva <DessyVV>` and
103+
:user:`Lucy Jimenez <LucyJimenez>`.
104+
98105
Code and Documentation Contributors
99106
-----------------------------------
100107

sklearn/neighbors/_kde.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,6 @@ def __init__(
135135
self.leaf_size = leaf_size
136136
self.metric_params = metric_params
137137

138-
# run the choose algorithm code so that exceptions will happen here
139-
# we're using clone() in the GenerativeBayes classifier,
140-
# so we can't do this kind of logic in __init__
141-
self._choose_algorithm(self.algorithm, self.metric)
142-
143-
if bandwidth <= 0:
144-
raise ValueError("bandwidth must be positive")
145-
if kernel not in VALID_KERNELS:
146-
raise ValueError("invalid kernel: '{0}'".format(kernel))
147-
148138
def _choose_algorithm(self, algorithm, metric):
149139
# given the algorithm string + metric string, choose the optimal
150140
# algorithm to compute the result.
@@ -188,7 +178,14 @@ def fit(self, X, y=None, sample_weight=None):
188178
self : object
189179
Returns the instance itself.
190180
"""
181+
191182
algorithm = self._choose_algorithm(self.algorithm, self.metric)
183+
184+
if self.bandwidth <= 0:
185+
raise ValueError("bandwidth must be positive")
186+
if self.kernel not in VALID_KERNELS:
187+
raise ValueError("invalid kernel: '{0}'".format(self.kernel))
188+
192189
X = self._validate_data(X, order="C", dtype=DTYPE)
193190

194191
if sample_weight is not None:

sklearn/neighbors/tests/test_kde.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,12 @@ def test_kde_algorithm_metric_choice(algorithm, metric):
107107
X = rng.randn(10, 2) # 2 features required for haversine dist.
108108
Y = rng.randn(10, 2)
109109

110+
kde = KernelDensity(algorithm=algorithm, metric=metric)
111+
110112
if algorithm == "kd_tree" and metric not in KDTree.valid_metrics:
111113
with pytest.raises(ValueError):
112-
KernelDensity(algorithm=algorithm, metric=metric)
114+
kde.fit(X)
113115
else:
114-
kde = KernelDensity(algorithm=algorithm, metric=metric)
115116
kde.fit(X)
116117
y_dens = kde.score_samples(Y)
117118
assert y_dens.shape == Y.shape[:1]
@@ -126,16 +127,17 @@ def test_kde_score(n_samples=100, n_features=3):
126127

127128

128129
def test_kde_badargs():
130+
X = np.random.random((200, 10))
129131
with pytest.raises(ValueError):
130-
KernelDensity(algorithm="blah")
132+
KernelDensity(algorithm="blah").fit(X)
131133
with pytest.raises(ValueError):
132-
KernelDensity(bandwidth=0)
134+
KernelDensity(bandwidth=0).fit(X)
133135
with pytest.raises(ValueError):
134-
KernelDensity(kernel="blah")
136+
KernelDensity(kernel="blah").fit(X)
135137
with pytest.raises(ValueError):
136-
KernelDensity(metric="blah")
138+
KernelDensity(metric="blah").fit(X)
137139
with pytest.raises(ValueError):
138-
KernelDensity(algorithm="kd_tree", metric="blah")
140+
KernelDensity(algorithm="kd_tree", metric="blah").fit(X)
139141
kde = KernelDensity()
140142
with pytest.raises(ValueError):
141143
kde.fit(np.random.random((200, 10)), sample_weight=np.random.random((200, 10)))

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def test_transformers_get_feature_names_out(transformer):
412412
"FeatureUnion",
413413
"GridSearchCV",
414414
"HalvingGridSearchCV",
415-
"KernelDensity",
416415
"KernelPCA",
417416
"LabelBinarizer",
418417
"NuSVC",

0 commit comments

Comments
 (0)
0