8000 Merge pull request #5234 from vighneshbirodkar/mcd_fix · scikit-learn/scikit-learn@8d273a1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d273a1

Browse files
committed
Merge pull request #5234 from vighneshbirodkar/mcd_fix
[MRG + 1]Deprecating 1D inputs in fast_mcd
2 parents da9a7cd + f0121b7 commit 8d273a1

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

sklearn/covariance/robust_covariance.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,7 @@ def fast_mcd(X, support_fraction=None,
361361
"""
362362
random_state = check_random_state(random_state)
363363

364-
X = np.asarray(X)
365-
if X.ndim == 1:
366-
X = np.reshape(X, (1, -1))
367-
warnings.warn("Only one sample available. "
368-
"You may want to reshape your data array")
364+
X = check_array(X, ensure_min_samples=2, estimator='fast_mcd')
369365
n_samples, n_features = X.shape
370366

371367
# minimum breakdown value
@@ -609,7 +605,7 @@ def fit(self, X, y=None):
609605
Returns self.
610606
611607
"""
612-
X = check_array(X)
608+
X = check_array(X, ensure_min_samples=2, estimator='MinCovDet')
613609
random_state = check_random_state(self.random_state)
614610
n_samples, n_features = X.shape
615611
# check that the empirical covariance is full rank

sklearn/covariance/tests/test_robust_covariance.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88

99
from sklearn.utils.testing import assert_almost_equal
1010
from sklearn.utils.testing import assert_array_almost_equal
11-
from sklearn.utils.testing import assert_raises
11+
from sklearn.utils.testing import assert_raises, assert_warns
12+
from sklearn.utils.testing import assert_raise_message
1213
from sklearn.utils.validation import NotFittedError
1314

1415
from sklearn import datasets
1516
from sklearn.covariance import empirical_covariance, MinCovDet, \
1617
EllipticEnvelope
18+
from sklearn.covariance import fast_mcd
1719

1820
X = datasets.load_iris().data
1921
X_1d = X[:, 0]
@@ -40,6 +42,19 @@ def test_mcd():
4042
launch_mcd_on_dataset(500, 1, 100, 0.001, 0.001, 350)
4143

4244

45+
def test_fast_mcd_on_invalid_input():
46+
X = np.arange(100)
47+
assert_raise_message(ValueError, 'fast_mcd expects at least 2 samples',
48+
fast_mcd, X)
49+
50+
51+
def test_mcd_class_on_invalid_input():
52+
X = np.arange(100)
53+
mcd = MinCovDet()
54+
assert_raise_message(ValueError, 'MinCovDet expects at least 2 samples',
55+
mcd.fit, X)
56+
57+
4358
def launch_mcd_on_dataset(n_samples, n_features, n_outliers, tol_loc, tol_cov,
4459
tol_support):
4560

sklearn/utils/validation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,10 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
374374

375375
if ensure_2d:
376376
if array.ndim == 1:
377+
if ensure_min_samples >= 2:
378+
raise ValueError("%s expects at least 2 samples provided "
379+
"in a 2 dimensional array-like input"
380+
% estimator_name)
377381
warnings.warn(
378382
"Passing 1d arrays as data is deprecated in 0.17 and will"
379383
"raise ValueError in 0.19. Reshape your data either using "

0 commit comments

Comments
 (0)
0