Merge pull request #5057 from amueller/fix_2d_y_svm_regressor · scikit-learn/scikit-learn@3a6c4dc · GitHub
[go: up one dir, main page]

Skip to content

Commit 3a6c4dc

Browse files
committed
Merge pull request #5057 from amueller/fix_2d_y_svm_regressor
[MRG + 2] FIX make sure we handle y.shape = (n_samples, 1) consistently in regressors.
2 parents cec3bf9 + 35927df commit 3a6c4dc

File tree

6 files changed

+44
-29
lines changed

6 files changed

+44
-29
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ Bug fixes
123123
- Fixed bug in :class:`ensemble.forest.ForestClassifier` while computing
124124
oob_score and X is a sparse.csc_matrix. By `Ankur Ankan`_.
125125

126+
- All regressors now consistently handle and warn when given ``y`` that is of
127+
shape ``(n_samples, 1)``. By `Andreas Müller`_.
128+
126129
API changes summary
127130
-------------------
128131

sklearn/linear_model/coordinate_descent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .base import center_data, sparse_center_data
1818
from ..utils import check_array, check_X_y, deprecated
1919
from ..utils.validation import check_random_state
20-
from ..cross_validation import check_cv
20+
from ..cross_validation import check_cv, column_or_1d
2121
from ..externals.joblib import Parallel, delayed
2222
from ..externals import six
2323
from ..externals.six.moves import xrange
@@ -1020,9 +1020,10 @@ def fit(self, X, y):
10201020
model = ElasticNet()
10211021
else:
10221022
model = Lasso()
1023-
if y.ndim > 1:
1023+
if y.ndim > 1 and y.shape[1] > 1:
10241024
raise ValueError("For multi-task outputs, use "
10251025
"MultiTask%sCV" % (model_str))
1026+
y = column_or_1d(y, warn=True)
10261027
else:
10271028
if sparse.isspmatrix(X):
10281029
raise TypeError("X should be dense but a sparse matrix was"

sklearn/linear_model/least_angle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ def fit(self, X, y, copy_X=True):
12951295
returns an instance of self.
12961296
"""
12971297
self.fit_path = True
1298-
X, y = check_X_y(X, y, multi_output=True, y_numeric=True)
1298+
X, y = check_X_y(X, y, y_numeric=True)
12991299

13001300
X, y, Xmean, ymean, Xstd = LinearModel._center_data(
13011301
X, y, self.fit_intercept, self.normalize, self.copy_X)

sklearn/linear_model/theil_sen.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
from .base import LinearModel
2121
from ..base import RegressorMixin
22-
from ..utils import check_array, check_random_state, ConvergenceWarning
23-
from ..utils import check_consistent_length, _get_n_jobs
22+
from ..utils import check_random_state, ConvergenceWarning
23+
from ..utils import check_X_y, _get_n_jobs
2424
from ..utils.random import choice
2525
from ..externals.joblib import Parallel, delayed
2626
from ..externals.six.moves import xrange as range
@@ -343,9 +343,7 @@ def fit(self, X, y):
343343
self : returns an instance of self.
344344
"""
345345
random_state = check_random_state(self.random_state)
346-
X = check_array(X)
347-
y = check_array(y, ensure_2d=False)
348-
check_consistent_length(X, y)
346+
X, y = check_X_y(X, y, y_numeric=True)
349347
n_samples, n_features = X.shape
350348
n_subsamples, self.n_subpopulation_ = self._check_subparams(n_samples,
351349
n_features)

sklearn/svm/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def __init__(self, impl, kernel, degree, gamma, coef0,
8080
"gamma='%s' as of 0.17. Backward compatibility"
8181
" for gamma=%s will be removed in %s")
8282
invalid_gamma = 0.0
83-
warnings.warn(msg % (invalid_gamma, "auto",
84-
invalid_gamma, "0.18"), DeprecationWarning)
83+
warnings.warn(msg % (invalid_gamma, "auto", invalid_gamma, "0.18"),
84+
DeprecationWarning)
8585

8686
self._impl = impl
8787
self.kernel = kernel
@@ -171,8 +171,8 @@ def fit(self, X, y, sample_weight=None):
171171
% (sample_weight.shape, X.shape))
172172

173173
# FIXME remove (self.gamma == 0) in 0.18
174-
if (self.kernel in ['poly', 'rbf']) and ((self.gamma == 0)
175-
or (self.gamma == 'auto')):
174+
if (self.kernel in ['poly', 'rbf']) and ((self.gamma == 0) or
175+
(self.gamma == 'auto')):
176176
# if custom gamma is not provided ...
177177
self._gamma = 1.0 / X.shape[1]
178178
elif self.gamma == 'auto':
@@ -212,7 +212,7 @@ def _validate_targets(self, y):
212212
# XXX this is ugly.
213213
# Regression models should not have a class_weight_ attribute.
214214
self.class_weight_ = np.empty(0)
215-
return np.asarray(y, dtype=np.float64, order='C')
215+
return column_or_1d(y, warn=True).astype(np.float64)
216216

217217
def _warn_from_fit_status(self):
218218
assert self.fit_status_ in (0, 1)

sklearn/utils/estimator_checks.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@
4848

4949
BOSTON = None
5050
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
51+
MULTI_OUTPUT = ['CCA', 'DecisionTreeRegressor', 'ElasticNet',
52+
'ExtraTreeRegressor', 'ExtraTreesRegressor', 'GaussianProcess',
53+
'KNeighborsRegressor', 'KernelRidge', 'Lars', 'Lasso',
54+
'LassoLars', 'LinearRegression', 'MultiTaskElasticNet',
55+
'MultiTaskElasticNetCV', 'MultiTaskLasso', 'MultiTaskLassoCV',
56+
'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression',
57+
'RANSACRegressor', 'RadiusNeighborsRegressor',
58+
'RandomForestRegressor', 'Ridge', 'RidgeCV']
5159

5260

5361
def _yield_non_meta_checks(name, Estimator):
@@ -100,8 +108,7 @@ def _yield_classifier_checks(name, Classifier):
100108
# We don't raise a warning in these classifiers, as
101109
# the column y interface is used by the forests.
102110

103-
# test if classifiers can cope with y.shape = (n_samples, 1)
104-
yield check_classifiers_input_shapes
111+
yield check_supervised_y_2d
105112
# test if NotFittedError is raised
106113
yield check_estimators_unfitted
107114
if 'class_weight' in Classifier().get_params().keys():
@@ -116,6 +123,7 @@ def _yield_regressor_checks(name, Regressor):
116123
yield check_regressor_data_not_an_array
117124
yield check_estimators_partial_fit_n_features
118125
yield check_regressors_no_decision_function
126+
yield check_supervised_y_2d
119127
if name != 'CCA':
120128
# check that the regressor handles int input
121129
yield check_regressors_int
@@ -831,31 +839,36 @@ def check_estimators_unfitted(name, Estimator):
831839
est.predict_log_proba, X)
832840

833841

834-
def check_classifiers_input_shapes(name, Classifier):
835-
iris = load_iris()
836-
X, y = iris.data, iris.target
837-
X, y = shuffle(X, y, random_state=1)
838-
X = StandardScaler().fit_transform(X)
842+
def check_supervised_y_2d(name, Estimator):
843+
if "MultiTask" in name:
844+
# These only work on 2d, so this test makes no sense
845+
return
846+
rnd = np.random.RandomState(0)
847+
X = rnd.uniform(size=(10, 3))
848+
y = np.arange(10) % 3
839849
# catch deprecation warnings
840850
with warnings.catch_warnings(record=True):
841-
classifier = Classifier()
842-
set_fast_parameters(classifier)
843-
set_random_state(classifier)
851+
estimator = Estimator()
852+
set_fast_parameters(estimator)
853+
set_random_state(estimator)
844854
# fit
845-
classifier.fit(X, y)
846-
y_pred = classifier.predict(X)
855+
estimator.fit(X, y)
856+
y_pred = estimator.predict(X)
847857

848-
set_random_state(classifier)
858+
set_random_state(estimator)
849859
# Check that when a 2D y is given, a DataConversionWarning is
850860
# raised
851861
with warnings.catch_warnings(record=True) as w:
852862
warnings.simplefilter("always", DataConversionWarning)
853863
warnings.simplefilter("ignore", RuntimeWarning)
854-
classifier.fit(X, y[:, np.newaxis])
864+
estimator.fit(X, y[:, np.newaxis])
865+
y_pred_2d = estimator.predict(X)
855866
msg = "expected 1 DataConversionWarning, got: %s" % (
856867
", ".join([str(w_x) for w_x in w]))
857-
assert_equal(len(w), 1, msg)
858-
assert_array_equal(y_pred, classifier.predict(X))
868+
if name not in MULTI_OUTPUT:
869+
# check that we warned if we don't support multi-output
870+
assert_equal(len(w), 1, msg)
871+
assert_array_almost_equal(y_pred.ravel(), y_pred_2d.ravel())
859872

860873

861874
def check_classifiers_classes(name, Classifier):

0 commit comments

Comments
 (0)
0