@@ -58,7 +58,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
58
58
from ..tree import (DecisionTreeClassifier , DecisionTreeRegressor ,
59
59
ExtraTreeClassifier , ExtraTreeRegressor )
60
60
from ..tree ._tree import DTYPE , DOUBLE
61
- from ..utils import check_random_state , check_array
61
+ from ..utils import check_random_state , check_array , compute_class_weight
62
62
from ..utils .validation import DataConversionWarning
63
63
from .base import BaseEnsemble , _partition_estimators
64
64
@@ -122,7 +122,8 @@ def __init__(self,
122
122
n_jobs = 1 ,
123
123
random_state = None ,
124
124
verbose = 0 ,
125
- warm_start = False ):
125
+ warm_start = False ,
126
+ class_weight = None ):
126
127
super (BaseForest , self ).__init__ (
127
128
base_estimator = base_estimator ,
128
129
n_estimators = n_estimators ,
@@ -134,6 +135,7 @@ def __init__(self,
134
135
self .random_state = random_state
135
136
self .verbose = verbose
136
137
self .warm_start = warm_start
138
+ self .class_weight = class_weight
137
139
138
140
def apply (self , X ):
139
141
"""Apply trees in the forest to X, return leaf indices.
@@ -211,11 +213,17 @@ def fit(self, X, y, sample_weight=None):
211
213
212
214
self .n_outputs_ = y .shape [1 ]
213
215
214
- y = self ._validate_y (y )
216
+ y , cw = self ._validate_y_cw (y )
215
217
216
218
if getattr (y , "dtype" , None ) != DOUBLE or not y .flags .contiguous :
217
219
y = np .ascontiguousarray (y , dtype = DOUBLE )
218
220
221
+ if cw is not None :
222
+ if sample_weight is not None :
223
+ sample_weight *= cw
224
+ else :
225
+ sample_weight = cw
226
+
219
227
# Check parameters
220
228
self ._validate_estimator ()
221
229
@@ -279,9 +287,9 @@ def fit(self, X, y, sample_weight=None):
279
287
def _set_oob_score (self , X , y ):
280
288
"""Calculate out of bag predictions and score."""
281
289
282
- def _validate_y (self , y ):
290
+ def _validate_y_cw (self , y ):
283
291
# Default implementation
284
- return y
292
+ return y , None
285
293
286
294
@property
287
295
def feature_importances_ (self ):
@@ -320,7 +328,8 @@ def __init__(self,
320
328
n_jobs = 1 ,
321
329
random_state = None ,
322
330
verbose = 0 ,
323
- warm_start = False ):
331
+ warm_start = False ,
332
+ class_weight = None ):
324
333
325
334
super (ForestClassifier , self ).__init__ (
326
335
base_estimator ,
@@ -331,7 +340,8 @@ def __init__(self,
331
340
n_jobs = n_jobs ,
332
341
random_state = random_state ,
333
342
verbose = verbose ,
334
- warm_start = warm_start )
343
+ warm_start = warm_start ,
344
+ class_weight = class_weight )
335
345
336
346
def _set_oob_score (self , X , y ):
337
347
"""Compute out-of-bag score"""
@@ -377,8 +387,9 @@ def _set_oob_score(self, X, y):
377
387
378
388
self .oob_score_ = oob_score / self .n_outputs_
379
389
380
- def _validate_y (self , y ):
381
- y = np .copy (y )
390
+ def _validate_y_cw (self , y_org ):
391
+ y = np .copy (y_org )
392
+ cw = None
382
393
383
394
self .classes_ = []
384
395
self .n_classes_ = []
@@ -388,7 +399,19 @@ def _validate_y(self, y):
388
399
self .classes_ .append (classes_k )
389
400
self .n_classes_ .append (classes_k .shape [0 ])
390
401
391
- return y
402
+ if self .class_weight is not None :
403
+ if self .n_outputs_ == 1 :
404
+ cw = compute_class_weight (self .class_weight ,
405
+ self .classes_ [0 ],
406
+ y_org [:, 0 ])
407
+ cw = cw [np .searchsorted (self .classes_ [0 ], y_org [:, 0 ])]
408
+ else :
409
+ raise NotImplementedError ('class_weights are not supported '
410
+ 'for multi-output. You may use '
411
+ 'sample_weights in the fit method '
412
+ 'to weight by sample.' )
413
+
414
+ return y , cw
392
415
393
416
def predict (self , X ):
394
417
"""Predict class for X.
@@ -707,6 +730,18 @@ class RandomForestClassifier(ForestClassifier):
707
730
and add more estimators to the ensemble, otherwise, just fit a whole
708
731
new forest.
709
732
733
+ class_weight : dict, {class_label: weight} or "auto" or None, optional
734
+ Weights associated with classes. If not given, all classes
735
+ are supposed to have weight one.
736
+
737
+ The "auto" mode uses the values of y to automatically adjust
738
+ weights inversely proportional to class frequencies.
739
+
740
+ Note that this is only supported for single-output classification.
741
+
742
+ Note that these weights will be multiplied with sample_weight (passed
743
+ through the fit method) if sample_weight is specified.
744
+
710
745
Attributes
711
746
----------
712
747
estimators_ : list of DecisionTreeClassifier
@@ -755,7 +790,8 @@ def __init__(self,
755
790
n_jobs = 1 ,
756
791
random_state = None ,
757
792
verbose = 0 ,
758
- warm_start = False ):
793
+ warm_start = False ,
794
+ class_weight = None ):
759
795
super (RandomForestClassifier , self ).__init__ (
760
796
base_estimator = DecisionTreeClassifier (),
761
797
n_estimators = n_estimators ,
@@ -768,7 +804,8 @@ def __init__(self,
768
804
n_jobs = n_jobs ,
769
805
random_state = random_state ,
770
806
verbose = verbose ,
771
- warm_start = warm_start )
807
+ warm_start = warm_start ,
808
+ class_weight = class_weight )
772
809
773
810
self .criterion = criterion
774
811
self .max_depth = max_depth
@@ -1017,6 +1054,18 @@ class ExtraTreesClassifier(ForestClassifier):
1017
1054
and add more estimators to the ensemble, otherwise, just fit a whole
1018
1055
new forest.
1019
1056
1057
+ class_weight : dict, {class_label: weight} or "auto" or None, optional
1058
+ Weights associated with classes. If not given, all classes
1059
+ are supposed to have weight one.
1060
+
1061
+ The "auto" mode uses the values of y to automatically adjust
1062
+ weights inversely proportional to class frequencies.
1063
+
1064
+ Note that this is only supported for single-output classification.
1065
+
1066
+ Note that these weights will be multiplied with sample_weight (passed
1067
+ through the fit method) if sample_weight is specified.
1068
+
1020
1069
Attributes
1021
1070
----------
1022
1071
estimators_ : list of DecisionTreeClassifier
@@ -1068,7 +1117,8 @@ def __init__(self,
1068
1117
n_jobs = 1 ,
1069
1118
random_state = None ,
1070
1119
verbose = 0 ,
1071
- warm_start = False ):
1120
+ warm_start = False ,
1121
+ class_weight = None ):
1072
1122
super (ExtraTreesClassifier , self ).__init__ (
1073
1123
base_estimator = ExtraTreeClassifier (),
1074
1124
n_estimators = n_estimators ,
@@ -1080,7 +1130,8 @@ def __init__(self,
1080
1130
n_jobs = n_jobs ,
1081
1131
random_state = random_state ,
1082
1132
verbose = verbose ,
1083
- warm_start = warm_start )
1133
+ warm_start = warm_start ,
1134
+ class_weight = class_weight )
1084
1135
1085
1136
self .criterion = criterion
1086
1137
self .max_depth = max_depth
0 commit comments