8000 [MRG+1] Added unit test for adding classes_ property to GridSearchCV,… · paulha/scikit-learn@aaa338a · GitHub
[go: up one dir, main page]

Skip to content

Commit aaa338a

Browse files
abatulapaulha
authored andcommitted
[MRG+1] Added unit test for adding classes_ property to GridSearchCV, fixes scikit-learn#6298 (scikit-learn#7661)
* Fix issue scikit-learn#6298 Adds a "classes_" property to BaseSearchCV * Added test to ensure classes_ property is added to gridSearch correctly * Fixed formatting * Added test to ensure gridSearchCV with a regressor does not have a classes_ attribute * Fixed whitespace issues * Combined tests for the new GridSearchSV.classes_ property into a single test. * Removed trailing whitespace * Added what's new for pull request scikit-learn#7661 * Fixed formatting of update in what's new
1 parent 10f73c1 commit aaa338a

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ New features
1919
Enhancements
2020
............
2121

22+
- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
23+
that matches the ``classes_`` attribute of ``best_estimator_``. (`#7661
24+
<https://github.com/scikit-learn/scikit-learn/pull/7661>`_) by `Alyssa
25+
Batula`_ and `Dylan Werner-Meier`_.
26+
2227
- The ``min_weight_fraction_leaf`` constraint in tree construction is now
2328
more efficient, taking a fast path to declare a node a leaf if its weight
2429
is less than 2 * the minimum. Note that the constructed tree will be

sklearn/grid_search.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,10 @@ def __init__(self, estimator, scoring=None,
387387
def _estimator_type(self):
388388
return self.estimator._estimator_type
389389

390+
@property
391+
def classes_(self):
392+
return self.best_estimator_.classes_
393+
390394
def score(self, X, y=None):
391395
"""Returns the score on the given data, if the estimator has been refit.
392396
@@ -688,7 +692,7 @@ class GridSearchCV(BaseSearchCV):
688692
- An iterable yielding train/test splits.
689693
690694
For integer/None inputs, if the estimator is a classifier and ``y`` is
691-
either binary or multiclass,
695+
either binary or multiclass,
692696
:class:`sklearn.model_selection.StratifiedKFold` is used. In all
693697
other cases, :class:`sklearn.model_selection.KFold` is used.
694698
@@ -900,7 +904,7 @@ class RandomizedSearchCV(BaseSearchCV):
900904
- An iterable yielding train/test splits.
901905
902906
For integer/None inputs, if the estimator is a classifier and ``y`` is
903-
either binary or multiclass,
907+
either binary or multiclass,
904908
:class:`sklearn.model_selection.StratifiedKFold` is used. In all
905909
other cases, :class:`sklearn.model_selection.KFold` is used.
906910

sklearn/tests/test_grid_search.py

Lines changed: 18 additions & 0 deletions
6CFA
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from sklearn.metrics import f1_score
4343
from sklearn.metrics import make_scorer
4444
from sklearn.metrics import roc_auc_score
45+
from sklearn.linear_model import Ridge
4546

4647
from sklearn.exceptions import ChangedBehaviorWarning
4748
from sklearn.exceptions import FitFailedWarning
@@ -785,3 +786,20 @@ def test_parameters_sampler_replacement():
785786
sampler = ParameterSampler(params_distribution, n_iter=7)
786787
samples = list(sampler)
787788
assert_equal(len(samples), 7)
789+
790+
791+
def test_classes__property():
792+
# Test that classes_ property matches best_esimator_.classes_
793+
X = np.arange(100).reshape(10, 10)
794+
y = np.array([0] * 5 + [1] * 5)
795+
Cs = [.1, 1, 10]
796+
797+
grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
798+
grid_search.fit(X, y)
799+
assert_array_equal(grid_search.best_estimator_.classes_,
800+
grid_search.classes_)
801+
802+
# Test that regressors do not have a classes_ attribute
803+
grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
804+
grid_search.fit(X, y)
805+
assert_false(hasattr(grid_search, 'classes_'))

0 commit comments

Comments
 (0)
0