[MRG+1] Add classes_ parameter to hyperparameter CV classes (#8295) · NelleV/scikit-learn@363558e · GitHub
[go: up one dir, main page]

Skip to content

Commit 363558e

Browse files
Stephen Hoover authored and NelleV committed
[MRG+1] Add classes_ parameter to hyperparameter CV classes (scikit-learn#8295)
1 parent e765cc5 commit 363558e

File tree

4 files changed

+53
-7
lines changed

4 files changed

+53
-7
lines changed

doc/whats_new.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,12 @@ Enhancements
6666
now uses significantly less memory when assigning data points to their
6767
nearest cluster center. :issue:`7721` by :user:`Jon Crall <Erotemic>`.
6868

69-
- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
70-
that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661`
71-
by :user:`Alyssa Batula <abatula>` and :user:`Dylan Werner-Meier <unautre>`.
69+
- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`,
70+
:class:`model_selection.RandomizedSearchCV`, :class:`grid_search.GridSearchCV`,
71+
and :class:`grid_search.RandomizedSearchCV` that matches the ``classes_``
72+
attribute of ``best_estimator_``. :issue:`7661` and :issue:`8295`
73+
by :user:`Alyssa Batula <abatula>`, :user:`Dylan Werner-Meier <unautre>`,
74+
and :user:`Stephen Hoover <stephen-hoover>`.
7275

7376
- The ``min_weight_fraction_leaf`` constraint in tree construction is now
7477
more efficient, taking a fast path to declare a node a leaf if its weight

sklearn/model_selection/_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,11 @@ def inverse_transform(self, Xt):
532532
self._check_is_fitted('inverse_transform')
533533
return self.best_estimator_.transform(Xt)
534534

535+
@property
536+
def classes_(self):
537+
self._check_is_fitted("classes_")
538+
return self.best_estimator_.classes_
539+
535540
def fit(self, X, y=None, groups=None, **fit_params):
536541
"""Run fit with all sets of parameters.
537542

sklearn/model_selection/tests/test_search.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from sklearn.metrics import roc_auc_score
6060
from sklearn.preprocessing import Imputer
6161
from sklearn.pipeline import Pipeline
62-
from sklearn.linear_model import SGDClassifier
62+
from sklearn.linear_model import Ridge, SGDClassifier
6363

6464
from sklearn.model_selection.tests.common import OneTimeSplitter
6565

@@ -73,6 +73,7 @@ def __init__(self, foo_param=0):
7373

7474
def fit(self, X, Y):
7575
assert_true(len(X) == len(Y))
76+
self.classes_ = np.unique(Y)
7677
return self
7778

7879
def predict(self, T):
@@ -323,6 +324,33 @@ def test_grid_search_groups():
323324
gs.fit(X, y)
324325

325326

327+
def test_classes__property():
328+
# Test that classes_ property matches best_estimator_.classes_
329+
X = np.arange(100).reshape(10, 10)
330+
y = np.array([0] * 5 + [1] * 5)
331+
Cs = [.1, 1, 10]
332+
333+
grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
334+
grid_search.fit(X, y)
335+
assert_array_equal(grid_search.best_estimator_.classes_,
336+
grid_search.classes_)
337+
338+
# Test that regressors do not have a classes_ attribute
339+
grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
340+
grid_search.fit(X, y)
341+
assert_false(hasattr(grid_search, 'classes_'))
342+
343+
# Test that the grid searcher has no classes_ attribute before it's fit
344+
grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
345+
assert_false(hasattr(grid_search, 'classes_'))
346+
347+
# Test that the grid searcher has no classes_ attribute without a refit
348+
grid_search = GridSearchCV(LinearSVC(random_state=0),
349+
{'C': Cs}, refit=False)
350+
grid_search.fit(X, y)
351+
assert_false(hasattr(grid_search, 'classes_'))
352+
353+
326354
def test_trivial_cv_results_attr():
327355
# Test search over a "grid" with only one point.
328356
# Non-regression test: grid_scores_ wouldn't be set by GridSearchCV.

sklearn/model_selection/tests/test_validation.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from sklearn.datasets import make_multilabel_classification
6363

6464
from sklearn.model_selection.tests.common import OneTimeSplitter
65+
from sklearn.model_selection import GridSearchCV
6566

6667

6768
try:
@@ -914,7 +915,7 @@ def test_cross_val_predict_sparse_prediction():
914915
assert_array_almost_equal(preds_sparse, preds)
915916

916917

917-
def test_cross_val_predict_with_method():
918+
def check_cross_val_predict_with_method(est):
918919
iris = load_iris()
919920
X, y = iris.data, iris.target
920921
X, y = shuffle(X, y, random_state=0)
@@ -924,8 +925,6 @@ def test_cross_val_predict_with_method():
924925

925926
methods = ['decision_function', 'predict_proba', 'predict_log_proba']
926927
for method in methods:
927-
est = LogisticRegression()
928-
929928
predictions = cross_val_predict(est, X, y, method=method)
930929
assert_equal(len(predictions), len(y))
931930

@@ -955,6 +954,17 @@ def test_cross_val_predict_with_method():
955954
assert_array_equal(predictions, predictions_ystr)
956955

957956

957+
def test_cross_val_predict_with_method():
958+
check_cross_val_predict_with_method(LogisticRegression())
959+
960+
961+
def test_gridsearchcv_cross_val_predict_with_method():
962+
est = GridSearchCV(LogisticRegression(random_state=42),
963+
{'C': [0.1, 1]},
964+
cv=2)
965+
check_cross_val_predict_with_method(est)
966+
967+
958968
def get_expected_predictions(X, y, cv, classes, est, method):
959969

960970
expected_predictions = np.zeros([len(y), classes])

0 commit comments

Comments
 (0)
0