Description
Using a `GridSearchCV` (or `RandomizedSearchCV`) as the estimator for `cross_val_predict` with `method='predict_proba'` raises an `AttributeError`.
PR #7889 modified `sklearn.model_selection._validation._fit_and_predict` so that it uses the estimator's `classes_` attribute to ensure the ordering of the output columns. The `sklearn.model_selection._search.BaseSearchCV` subclasses do not expose their estimator's `classes_` attribute.
Suggested fix: add the property

```python
@property
def classes_(self):
    return self.best_estimator_.classes_
```

to the `sklearn.model_selection._search.BaseSearchCV` class. Possibly add a check that `best_estimator_` exists first, so that a more informative error can be raised when the grid searcher has not been fit yet. Maybe also check `self.estimator.classes_`?
Steps/Code to Reproduce
from sklearn.model_selection import GridSearchCV, cross_val_predict
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
est = GridSearchCV(LogisticRegression(), {'C': [1, 10]})
X, y = make_classification()
cross_val_predict(est, X, y, method='predict_proba')
Expected Results
No error is raised. The output when this code is run at tag `0.18.1` is an array of shape `(100, 2)`.
Actual Results
Run at commit 13cc121 (master branch as of this issue filing):
AttributeError Traceback (most recent call last)
<ipython-input-3-09bc2b6d94ad> in <module>()
4
5 est = GridSearchCV(LogisticRegression(), {'C': [1, 10]})
----> 6 cross_val_predict(est, X, y, method='predict_proba')
/Users/shoover/src/scikit-learn/sklearn/model_selection/_validation.py in cross_val_predict(estimator, X, y, groups, cv, n_jobs, verbose, fit_params, pre_dispatch, method)
404 prediction_blocks = parallel(delayed(_fit_and_predict)(
405 clone(estimator), X, y, train, test, verbose, fit_params, method)
--> 406 for train, test in cv.split(X, y, groups))
407
408 # Concatenate the predictions
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/parallel.py in __call__(self, iterable)
756 # was dispatched. In particular this covers the edge
757 # case of Parallel used with an exhausted iterator.
--> 758 while self.dispatch_one_batch(iterator):
759 self._iterating = True
760 else:
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/parallel.py in dispatch_one_batch(self, iterator)
606 return False
607 else:
--> 608 self._dispatch(tasks)
609 return True
610
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/parallel.py in _dispatch(self, batch)
569 dispatch_timestamp = time.time()
570 cb = BatchCompletionCallBack(dispatch_timestamp, len(batch), self)
--> 571 job = self._backend.apply_async(batch, callback=cb)
572 self._jobs.append(job)
573
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/_parallel_backends.py in apply_async(self, func, callback)
107 def apply_async(self, func, callback=None):
108 """Schedule a func to be run"""
--> 109 result = ImmediateResult(func)
110 if callback:
111 callback(result)
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/_parallel_backends.py in __init__(self, batch)
324 # Don't delay the application, to avoid keeping the input
325 # arguments in memory
--> 326 self.results = batch()
327
328 def get(self):
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/parallel.py in __call__(self)
129
130 def __call__(self):
--> 131 return [func(*args, **kwargs) for func, args, kwargs in self.items]
132
133 def __len__(self):
/Users/shoover/src/scikit-learn/sklearn/externals/joblib/parallel.py in <listcomp>(.0)
129
130 def __call__(self):
--> 131 return [func(*args, **kwargs) for func, args, kwargs in self.items]
132
133 def __len__(self):
/Users/shoover/src/scikit-learn/sklearn/model_selection/_validation.py in _fit_and_predict(estimator, X, y, train, test, verbose, fit_params, method)
486 predictions_[:, estimator.classes_[-1]] = predictions
487 else:
--> 488 predictions_[:, estimator.classes_] = predictions
489 predictions = predictions_
490 return predictions, test
AttributeError: 'GridSearchCV' object has no attribute 'classes_'
Versions
Darwin-15.6.0-x86_64-i386-64bit
Python 3.5.2 | packaged by conda-forge | (default, Jul 26 2016, 01:37:38)
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.54)]
NumPy 1.11.2
SciPy 0.17.0
Scikit-Learn 0.19.dev0