diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index c2c01caf68e2e..337bbc20b3928 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -103,7 +103,7 @@ def _yield_non_meta_checks(name, Estimator):
         yield check_sparsify_coefficients
 
     yield check_estimator_sparse_data
-
+    yield check_estimator_sparse_dense
     # Test that estimators can be pickled, and once pickled
     # give the same answer as before.
     yield check_estimators_pickle
@@ -312,6 +312,11 @@ def set_testing_parameters(estimator):
         if not isinstance(estimator, ProjectedGradientNMF):
             estimator.set_params(solver='cd')
 
+    if "KNeighbors" in estimator.__class__.__name__ :
+        # Override the default 'auto' for sparse dense equivalence
+        # since only 'brute' algo is used for sparse see #1572
+        estimator.set_params(algorithm='brute')
+
 
 class NotAnArray(object):
     " An object that is convertable to an array"
@@ -344,6 +349,7 @@ def check_estimator_sparse_data(name, Estimator):
             estimator = Estimator()
         set_testing_parameters(estimator)
         # fit and predict
+
         try:
             with ignore_warnings(category=DeprecationWarning):
                 estimator.fit(X, y)
@@ -1554,3 +1560,65 @@ def check_classifiers_regression_target(name, Estimator):
     e = Estimator()
     msg = 'Unknown label type: '
     assert_raises_regex(ValueError, msg, e.fit, X, y)
+
+
+def check_estimator_sparse_dense(name, Estimator):
+    """Check that sparse and dense inputs give equivalent results.
+
+    For each scipy sparse format, two estimators built with the same
+    parameters are fitted, one on a dense array and one on the same data
+    in sparse form; their ``predict`` and ``predict_proba`` outputs must
+    agree to 2 decimals.  A TypeError whose repr mentions 'sparse' is
+    accepted as a graceful refusal of that format; any other exception
+    is reported and re-raised.
+    """
+    rng = np.random.RandomState(0)
+    X = rng.rand(40, 10)
+    X[X < .8] = 0
+    X_csr = sparse.csr_matrix(X)
+    # builtin int: the np.int alias is deprecated in recent NumPy
+    y = (4 * rng.rand(40)).astype(int)
+    for sparse_format in ['csr', 'csc', 'dok', 'lil', 'coo', 'dia', 'bsr']:
+        X_sp = X_csr.asformat(sparse_format)
+        # catch deprecation warnings
+        with ignore_warnings(category=DeprecationWarning):
+            if name in ['Scaler', 'StandardScaler']:
+                estimator = Estimator(with_mean=False)
+                estimator_sp = Estimator(with_mean=False)
+            else:
+                estimator = Estimator()
+                estimator_sp = Estimator()
+        set_testing_parameters(estimator)
+        set_testing_parameters(estimator_sp)
+        set_random_state(estimator)
+        set_random_state(estimator_sp)
+        # fit and predict
+        try:
+            with ignore_warnings(category=DeprecationWarning):
+                estimator_sp.fit(X_sp, y)
+                estimator.fit(X, y)
+            if hasattr(estimator, "predict"):
+                pred = estimator.predict(X)
+                pred_sp = estimator_sp.predict(X_sp)
+                # check shapes first: a mismatch gives a clearer failure
+                assert_equal(pred.shape, pred_sp.shape)
+                assert_array_almost_equal(pred, pred_sp, 2)
+            if hasattr(estimator, 'predict_proba'):
+                probs = estimator.predict_proba(X)
+                probs_sp = estimator_sp.predict_proba(X_sp)
+                # labels in y lie in 0..3, so 4 probability columns expected
+                assert_equal(probs.shape, (X.shape[0], 4))
+                assert_equal(probs_sp.shape, probs.shape)
+                assert_array_almost_equal(probs, probs_sp, 2)
+        except TypeError as e:
+            if 'sparse' not in repr(e):
+                print("Estimator %s doesn't seem to fail gracefully on "
+                      "sparse data: error message state explicitly that "
+                      "sparse input is not supported if this is not the case."
+                      % name)
+                raise
+        except Exception:
+            print("Estimator %s doesn't seem to fail gracefully on "
+                  "sparse data: it should raise a TypeError if sparse input "
+                  "is explicitly not supported." % name)
+            raise