add support for class_weights · scikit-learn/scikit-learn@de35bea · GitHub

Commit de35bea

add support for class_weights

1 parent 6dab7c5 commit de35bea

File tree

3 files changed: +114 −14 lines changed

sklearn/ensemble/forest.py

Lines changed: 65 additions & 14 deletions
@@ -58,7 +58,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                     ExtraTreeClassifier, ExtraTreeRegressor)
 from ..tree._tree import DTYPE, DOUBLE
-from ..utils import check_random_state, check_array
+from ..utils import check_random_state, check_array, compute_class_weight
 from ..utils.validation import DataConversionWarning
 from .base import BaseEnsemble, _partition_estimators

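For context, `compute_class_weight` (newly imported above) turns a class_weight specification plus the observed labels into one weight per class. A minimal sketch of how it is called, using the positional signature this commit relies on (illustrative only; the exact 'auto' scaling is an implementation detail):

    import numpy as np
    from sklearn.utils import compute_class_weight

    y = np.array([0, 0, 0, 1])
    classes = np.array([0, 1])  # sorted unique labels, as np.unique returns them

    # 'auto': weights inversely proportional to class frequencies,
    # so the rare class 1 receives the larger weight.
    w_auto = compute_class_weight('auto', classes, y)

    # dict: explicit per-class weights, returned aligned with `classes`.
    w_dict = compute_class_weight({0: 1.0, 1: 10.0}, classes, y)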
@@ -122,7 +122,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
         super(BaseForest, self).__init__(
             base_estimator=base_estimator,
             n_estimators=n_estimators,
@@ -134,6 +135,7 @@ def __init__(self,
         self.random_state = random_state
         self.verbose = verbose
         self.warm_start = warm_start
+        self.class_weight = class_weight

     def apply(self, X):
         """Apply trees in the forest to X, return leaf indices.
@@ -211,11 +213,17 @@ def fit(self, X, y, sample_weight=None):

         self.n_outputs_ = y.shape[1]

-        y = self._validate_y(y)
+        y, cw = self._validate_y_cw(y)

         if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
             y = np.ascontiguousarray(y, dtype=DOUBLE)

+        if cw is not None:
+            if sample_weight is not None:
+                sample_weight *= cw
+            else:
+                sample_weight = cw
+
         # Check parameters
         self._validate_estimator()

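The added block folds the expanded class weights `cw` into sample_weight multiplicatively. A standalone sketch of that logic (hypothetical helper, not part of the diff):

    import numpy as np

    def combine_weights(sample_weight, cw):
        # `cw` holds one weight per sample, derived from class_weight;
        # combine it multiplicatively with any user-supplied sample_weight.
        if cw is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * cw
            else:
                sample_weight = cw
        return sample_weight

    cw = np.array([1.0, 100.0, 1.0])
    combine_weights(None, cw)                    # array([  1., 100.,   1.])
    combine_weights(np.array([2., 2., 2.]), cw)  # array([  2., 200.,   2.])

Note that the diff uses an in-place `sample_weight *= cw`, which mutates the array the caller passed to fit; the sketch above uses an out-of-place product to avoid that side effect.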
@@ -279,9 +287,9 @@ def fit(self, X, y, sample_weight=None):
     def _set_oob_score(self, X, y):
         """Calculate out of bag predictions and score."""

-    def _validate_y(self, y):
+    def _validate_y_cw(self, y):
         # Default implementation
-        return y
+        return y, None

     @property
     def feature_importances_(self):
@@ -320,7 +328,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):

         super(ForestClassifier, self).__init__(
             base_estimator,
@@ -331,7 +340,8 @@ def __init__(self,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            class_weight=class_weight)

     def _set_oob_score(self, X, y):
         """Compute out-of-bag score"""
@@ -377,8 +387,9 @@ def _set_oob_score(self, X, y):

         self.oob_score_ = oob_score / self.n_outputs_

-    def _validate_y(self, y):
-        y = np.copy(y)
+    def _validate_y_cw(self, y_org):
+        y = np.copy(y_org)
+        cw = None

         self.classes_ = []
         self.n_classes_ = []
@@ -388,7 +399,19 @@ def _validate_y(self, y):
             self.classes_.append(classes_k)
             self.n_classes_.append(classes_k.shape[0])

-        return y
+        if self.class_weight is not None:
+            if self.n_outputs_ == 1:
+                cw = compute_class_weight(self.class_weight,
+                                          self.classes_[0],
+                                          y_org[:, 0])
+                cw = cw[np.searchsorted(self.classes_[0], y_org[:, 0])]
+            else:
+                raise NotImplementedError('class_weights are not supported '
+                                          'for multi-output. You may use '
+                                          'sample_weights in the fit method '
+                                          'to weight by sample.')
+
+        return y, cw

     def predict(self, X):
         """Predict class for X.
@@ -707,6 +730,18 @@ class RandomForestClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.

+    class_weight : dict {class_label: weight}, "auto" or None, optional
+        Weights associated with classes. If not given, all classes
+        are supposed to have weight one.
+
+        The "auto" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies.
+
+        Note that this is only supported for single-output classification.
+
+        Note that these weights will be multiplied with sample_weight
+        (passed through the fit method) if sample_weight is specified.
+
     Attributes
     ----------
     estimators_ : list of DecisionTreeClassifier
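A possible call pattern for the new parameter (a sketch mirroring the tests added below, not an excerpt from the diff):

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier

    iris = load_iris()

    # Explicit per-class weights: up-weight class 1 a hundredfold...
    clf = RandomForestClassifier(class_weight={0: 1., 1: 100., 2: 1.},
                                 random_state=0)
    clf.fit(iris.data, iris.target)

    # ...or let 'auto' derive weights from the class frequencies in y.
    clf_auto = RandomForestClassifier(class_weight='auto', random_state=0)
    clf_auto.fit(iris.data, iris.target)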
@@ -755,7 +790,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
         super(RandomForestClassifier, self).__init__(
             base_estimator=DecisionTreeClassifier(),
             n_estimators=n_estimators,
@@ -768,7 +804,8 @@ def __init__(self,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            class_weight=class_weight)

         self.criterion = criterion
         self.max_depth = max_depth
@@ -1017,6 +1054,18 @@ class ExtraTreesClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.

+    class_weight : dict {class_label: weight}, "auto" or None, optional
+        Weights associated with classes. If not given, all classes
+        are supposed to have weight one.
+
+        The "auto" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies.
+
+        Note that this is only supported for single-output classification.
+
+        Note that these weights will be multiplied with sample_weight
+        (passed through the fit method) if sample_weight is specified.
+
     Attributes
     ----------
     estimators_ : list of DecisionTreeClassifier
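As with RandomForestClassifier, the parameter is single-output only; per the `_validate_y_cw` change above, fitting with a multi-output y raises NotImplementedError. An illustrative sketch:

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.ensemble import ExtraTreesClassifier

    iris = load_iris()
    y_multi = np.vstack((iris.target, iris.target * 2)).T  # two output columns

    clf = ExtraTreesClassifier(class_weight='auto')
    try:
        clf.fit(iris.data, y_multi)
    except NotImplementedError as exc:
        print(exc)  # class_weights are not supported for multi-output...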
@@ -1068,7 +1117,8 @@ def __init__(self,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
-                 warm_start=False):
+                 warm_start=False,
+                 class_weight=None):
         super(ExtraTreesClassifier, self).__init__(
             base_estimator=ExtraTreeClassifier(),
             n_estimators=n_estimators,
@@ -1080,7 +1130,8 @@ def __init__(self,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
-            warm_start=warm_start)
+            warm_start=warm_start,
+            class_weight=class_weight)

         self.criterion = criterion
         self.max_depth = max_depth

sklearn/ensemble/tests/test_forest.py

Lines changed: 47 additions & 0 deletions
@@ -747,6 +747,53 @@ def test_1d_input():
         yield check_1d_input, name, X, X_2d, y


+def check_class_weights(name):
+    """Check class_weights resemble sample_weights behavior."""
+    ForestClassifier = FOREST_CLASSIFIERS[name]
+
+    # Iris is balanced, so no effect expected for using 'auto' weights
+    clf1 = ForestClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target)
+    clf2 = ForestClassifier(class_weight='auto', random_state=0)
+    clf2.fit(iris.data, iris.target)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+    # Inflate importance of class 1, check against user-defined weights
+    sample_weight = np.ones(iris.target.shape)
+    sample_weight[iris.target == 1] *= 100
+    class_weight = {0: 1., 1: 100., 2: 1.}
+    clf1 = ForestClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target, sample_weight)
+    clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
+    clf2.fit(iris.data, iris.target)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+    # Check that sample_weight and class_weight are multiplicative
+    clf1 = ForestClassifier(random_state=0)
+    clf1.fit(iris.data, iris.target, sample_weight**2)
+    clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
+    clf2.fit(iris.data, iris.target, sample_weight)
+    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
+
+
+def test_class_weights():
+    for name in FOREST_CLASSIFIERS:
+        yield check_class_weights, name
+
+
+def check_class_weight_failure_multi_output(name):
+    """Test class_weight failure for multi-output."""
+    ForestClassifier = FOREST_CLASSIFIERS[name]
+    _y = np.vstack((y, np.array(y) * 2)).T
+    clf = ForestClassifier(class_weight='auto')
+    assert_raises(NotImplementedError, clf.fit, X, _y)
+
+
+def test_class_weight_failure_multi_output():
+    for name in FOREST_CLASSIFIERS:
+        yield check_class_weight_failure_multi_output, name
+
+
 def check_warm_start(name, random_state=42):
     """Test if fitting incrementally with warm start gives a forest of the
     right size and the same results as a normal fit."""

sklearn/utils/estimator_checks.py

Lines changed: 2 additions & 0 deletions
@@ -737,6 +737,8 @@ def check_class_weight_classifiers(name, Classifier):
     classifier = Classifier(class_weight=class_weight)
     if hasattr(classifier, "n_iter"):
         classifier.set_params(n_iter=100)
+    if hasattr(classifier, "min_weight_fraction_leaf"):
+        classifier.set_params(min_weight_fraction_leaf=0.01)

     set_random_state(classifier)
     classifier.fit(X_train, y_train)
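Presumably `min_weight_fraction_leaf` is enabled here so that class_weight visibly changes the grown trees in this common check, rather than being absorbed by fully grown, zero-error leaves (the commit does not state the rationale). The guarded set_params pattern itself keeps the check generic across estimators:

    from sklearn.ensemble import RandomForestClassifier

    classifier = RandomForestClassifier(class_weight='auto')
    # Only estimators exposing the knob are configured with it;
    # min_weight_fraction_leaf requires each leaf to hold at least
    # 1% of the total (weighted) number of samples.
    if hasattr(classifier, "min_weight_fraction_leaf"):
        classifier.set_params(min_weight_fraction_leaf=0.01)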

0 commit comments