File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change 88from scipy import sparse
99
1010from .metrics import r2_score
11+ from .utils .fixes import unique
1112
1213
1314###############################################################################
@@ -261,6 +262,24 @@ def __str__(self):
261262class ClassifierMixin (object ):
262263 """Mixin class for all classifiers in scikit-learn"""
263264
265+ def _check_classes (self , classes ):
266+ print classes
267+ """Common error checking for the prepare functions below."""
268+ if len (classes ) is 0 :
269+ raise ValueError ("no output classes" )
270+ if len (classes ) != len (set (classes )):
271+ raise ValueError ("duplicate class label" )
272+
273+ def _prepare_classes (self , y ):
274+ """Set self.classes and self.y_inverse_"""
275+ if self .classes is None :
276+ self .classes_ , self .y_inverse_ = unique (y , return_inverse = True )
277+ else :
278+ self ._check_classes (self .classes )
279+ self .classes_ = self .classes
280+ assoc = {v :k for k ,v in enumerate (y )}
281+ self .y_inverse_ = np .array ([assoc [k ] for k in y ])
282+
264283 def score (self , X , y ):
265284 """Returns the mean accuracy on the given test data and labels.
266285
You can’t perform that action at this time.
0 commit comments