Add test that score takes y, fix KMeans, FIX pipeline compatibility of clustering algorithms! · scikit-learn/scikit-learn@d55d272 · GitHub
[go: up one dir, main page]

Skip to content

Commit d55d272

Browse files
committed
Add test that score takes y, fix KMeans, FIX pipeline compatibility of clustering algorithms!
1 parent ac7c88c commit d55d272

File tree

7 files changed

+28
-32
lines changed

7 files changed

+28
-32
lines changed

sklearn/cluster/affinity_propagation_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def __init__(self, damping=.5, max_iter=200, convergence_iter=15,
269269
def _pairwise(self):
270270
return self.affinity == "precomputed"
271271

272-
def fit(self, X):
272+
def fit(self, X, y=None):
273273
""" Create affinity matrix from negative euclidean distances, then
274274
apply affinity propagation clustering.
275275

sklearn/cluster/hierarchical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def __init__(self, n_clusters=2, affinity="euclidean",
680680
self.affinity = affinity
681681
self.pooling_func = pooling_func
682682

683-
def fit(self, X):
683+
def fit(self, X, y=None):
684684
"""Fit the hierarchical clustering on the data
685685
686686
Parameters

sklearn/cluster/k_means_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def predict(self, X):
864864
x_squared_norms = row_norms(X, squared=True)
865865
return _labels_inertia(X, x_squared_norms, self.cluster_centers_)[0]
866866

867-
def score(self, X):
867+
def score(self, X, y=None):
868868
"""Opposite of the value of X on the K-means objective.
869869
870870
Parameters

sklearn/cluster/mean_shift_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def __init__(self, bandwidth=None, seeds=None, bin_seeding=False,
320320
self.cluster_all = cluster_all
321321
self.min_bin_freq = min_bin_freq
322322

323-
def fit(self, X):
323+
def fit(self, X, y=None):
324324
"""Perform clustering.
325325
326326
Parameters

sklearn/cluster/spectral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def __init__(self, n_clusters=8, eigen_solver=None, random_state=None,
405405
self.coef0 = coef0
406406
self.kernel_params = kernel_params
407407

408-
def fit(self, X):
408+
def fit(self, X, y=None):
409409
"""Creates an affinity matrix for X using the selected affinity,
410410
then applies spectral clustering to this affinity matrix.
411411

sklearn/tests/test_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
check_regressor_data_not_an_array,
5656
check_transformer_data_not_an_array,
5757
check_transformer_n_iter,
58+
check_fit_score_takes_y,
5859
check_non_transformer_estimators_n_iter,
5960
CROSS_DECOMPOSITION)
6061

@@ -87,6 +88,8 @@ def test_non_meta_estimators():
8788
estimators = all_estimators(type_filter=['classifier', 'regressor',
8889
'transformer', 'cluster'])
8990
for name, Estimator in estimators:
91+
if hasattr(Estimator, "score") and name not in CROSS_DECOMPOSITION:
92+
yield check_fit_score_takes_y, name, Estimator
9093
if name not in CROSS_DECOMPOSITION + ['Imputer']:
9194
# Test that all estimators check their input for NaN's and infs
9295
yield check_estimators_nan_inf, name, Estimator

sklearn/utils/estimator_checks.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
from sklearn.utils.testing import SkipTest
2424
from sklearn.utils.testing import check_skip_travis
2525

26-
from sklearn.base import (clone, ClusterMixin, ClassifierMixin, RegressorMixin,
27-
TransformerMixin)
26+
from sklearn.base import clone, ClassifierMixin
2827
from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score
2928

3029
from sklearn.lda import LDA
@@ -44,13 +43,6 @@
4443
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
4544

4645

47-
def is_supervised(estimator):
48-
return (isinstance(estimator, ClassifierMixin)
49-
or isinstance(estimator, RegressorMixin)
50-
# transformers can all take a y
51-
or isinstance(estimator, TransformerMixin))
52-
53-
5446
def _boston_subset(n_samples=200):
5547
global BOSTON
5648
if BOSTON is None:
@@ -131,10 +123,7 @@ def check_estimator_sparse_data(name, Estimator):
131123
set_fast_parameters(estimator)
132124
# fit and predict
133125
try:
134-
if is_supervised(estimator):
135-
estimator.fit(X, y)
136-
else:
137-
estimator.fit(X)
126+
estimator.fit(X, y)
138127
if hasattr(estimator, "predict"):
139128
estimator.predict(X)
140129
if hasattr(estimator, 'predict_proba'):
@@ -252,6 +241,21 @@ def _check_transformer(name, Transformer, X, y):
252241
assert_raises(ValueError, transformer.transform, X.T)
253242

254243

244+
def check_fit_score_takes_y(name, Estimator):
245+
# check that all estimators accept an optional y
246+
# in fit and score so they can be used in pipelines
247+
rnd = np.random.RandomState(0)
248+
X = rnd.uniform(size=(10, 3))
249+
y = (X[:, 0] * 4).astype(np.int)
250+
y = multioutput_estimator_convert_y_2d(name, y)
251+
with warnings.catch_warnings(record=True):
252+
estimator = Estimator()
253+
set_fast_parameters(estimator)
254+
set_random_state(estimator)
255+
estimator.fit(X, y)
256+
estimator.score(X, y)
257+
258+
255259
def check_estimators_nan_inf(name, Estimator):
256260
rnd = np.random.RandomState(0)
257261
X_train_finite = rnd.uniform(size=(10, 3))
@@ -275,10 +279,7 @@ def check_estimators_nan_inf(name, Estimator):
275279
set_random_state(estimator, 1)
276280
# try to fit
277281
try:
278-
if issubclass(Estimator, ClusterMixin):
279-
estimator.fit(X_train)
280-
else:
281-
estimator.fit(X_train, y)
282+
estimator.fit(X_train, y)
282283
except ValueError as e:
283284
if 'inf' not in repr(e) and 'NaN' not in repr(e):
284285
print(error_string_fit, Estimator, e)
@@ -291,12 +292,7 @@ def check_estimators_nan_inf(name, Estimator):
291292
else:
292293
raise AssertionError(error_string_fit, Estimator)
293294
# actually fit
294-
if issubclass(Estimator, ClusterMixin):
295-
# All estimators except clustering algorithm
296-
# support fitting with (optional) y
297-
estimator.fit(X_train_finite)
298-
else:
299-
estimator.fit(X_train_finite, y)
295+
estimator.fit(X_train_finite, y)
300296

301297
# predict
302298
if hasattr(estimator, "predict"):
@@ -833,10 +829,7 @@ def check_estimators_overwrite_params(name, Estimator):
833829
set_random_state(estimator)
834830

835831
params = estimator.get_params()
836-
if is_supervised(estimator):
837-
estimator.fit(X, y)
838-
else:
839-
estimator.fit(X)
832+
estimator.fit(X, y)
840833
new_params = estimator.get_params()
841834
for k, v in params.items():
842835
assert_false(np.any(new_params[k] != v),

0 commit comments

Comments
 (0)
0