8000 base: Add code to find unique classes_ given classes or y. · erg/scikit-learn@3edcb15 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3edcb15

Browse files
committed
base: Add code to find unique classes_ given classes or y.
1 parent 56afa38 commit 3edcb15

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

sklearn/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from scipy import sparse
99

1010
from .metrics import r2_score
11+
from .utils.fixes import unique
1112

1213

1314
###############################################################################
@@ -261,6 +262,24 @@ def __str__(self):
261262
class ClassifierMixin(object):
262263
"""Mixin class for all classifiers in scikit-learn"""
263264

265+
def _check_classes(self, classes):
266+
print classes
267+
"""Common error checking for the prepare functions below."""
268+
if len(classes) is 0:
269+
raise ValueError("no output classes")
270+
if len(classes) != len(set(classes)):
271+
raise ValueError("duplicate class label")
272+
273+
def _prepare_classes(self, y):
274+
"""Set self.classes and self.y_inverse_"""
275+
if self.classes is None:
276+
self.classes_, self.y_inverse_ = unique(y, return_inverse=True)
277+
else:
278+
self._check_classes(self.classes)
279+
self.classes_ = self.classes
280+
assoc = {v:k for k,v in enumerate(y)}
281+
self.y_inverse_ = np.array([assoc[k] for k in y])
282+
264283
def score(self, X, y):
265284
"""Returns the mean accuracy on the given test data and labels.
266285

0 commit comments

Comments
 (0)
0