FIX add class_weight="balanced_subsample" to the forests to keep backward compatibility to 0.16 · scikit-learn/scikit-learn@df3d97b

Commit df3d97b

FIX add class_weight="balanced_subsample" to the forests to keep backward compatibility to 0.16
1 parent 523addf commit df3d97b

File tree: 2 files changed (+33, -12 lines)
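For orientation, a minimal usage sketch of the preset this commit adds. It is not part of the diff; the toy data below is made up, and the behavior described applies to the 0.17 release this commit targets.

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Illustrative imbalanced toy data (hypothetical, not from the commit).
X = np.random.RandomState(0).rand(100, 4)
y = np.array([0] * 90 + [1] * 10)

# New preset: "balanced" weights recomputed on each tree's bootstrap sample.
clf = RandomForestClassifier(n_estimators=10,
                             class_weight='balanced_subsample',
                             random_state=0).fit(X, y)

# The old name is kept for 0.16 compatibility in 0.17 but, per this commit,
# emits a DeprecationWarning; the warning says it will be removed in 0.18.
clf_old = RandomForestClassifier(n_estimators=10,
                                 class_weight='subsample',
                                 random_state=0).fit(X, y)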

sklearn/ensemble/forest.py

Lines changed: 28 additions & 12 deletions
@@ -41,7 +41,9 @@ class calls the ``fit`` method of each sub-estimator on random samples
 
 from __future__ import division
 
+import warnings
 from warnings import warn
+
 from abc import ABCMeta, abstractmethod
 
 import numpy as np
@@ -89,6 +91,10 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
         curr_sample_weight *= sample_counts
 
         if class_weight == 'subsample':
+            with warnings.catch_warnings():
+                warnings.simplefilter('ignore', DeprecationWarning)
+                curr_sample_weight *= compute_sample_weight('auto', y, indices)
+        elif class_weight == 'balanced_subsample':
             curr_sample_weight *= compute_sample_weight('balanced', y, indices)
 
         tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
@@ -414,30 +420,40 @@ def _validate_y_class_weight(self, y):
             self.n_classes_.append(classes_k.shape[0])
 
         if self.class_weight is not None:
-            valid_presets = ('auto', 'balanced', 'subsample', 'auto')
+            valid_presets = ('auto', 'balanced', 'balanced_subsample', 'subsample', 'auto')
             if isinstance(self.class_weight, six.string_types):
                 if self.class_weight not in valid_presets:
                     raise ValueError('Valid presets for class_weight include '
-                                     '"balanced" and "subsample". Given "%s".'
+                                     '"balanced" and "balanced_subsample". Given "%s".'
                                      % self.class_weight)
+                if self.class_weight == "subsample":
+                    warn("class_weight='subsample' is deprecated and will be removed in 0.18."
+                         " It was replaced by class_weight='balanced_subsample' "
+                         "using the balanced strategy.", DeprecationWarning)
                 if self.warm_start:
-                    warn('class_weight presets "balanced" or "subsample" are '
+                    warn('class_weight presets "balanced" or "balanced_subsample" are '
                          'not recommended for warm_start if the fitted data '
                          'differs from the full dataset. In order to use '
-                         '"auto" weights, use compute_class_weight("balanced", '
+                         '"balanced" weights, use compute_class_weight("balanced", '
                          'classes, y). In place of y you can use a large '
                          'enough sample of the full training set target to '
                          'properly estimate the class frequency '
                          'distributions. Pass the resulting weights as the '
                          'class_weight parameter.')
 
-            if self.class_weight != 'subsample' or not self.bootstrap:
+            if (self.class_weight not in ['subsample', 'balanced_subsample'] or
+                    not self.bootstrap):
                 if self.class_weight == 'subsample':
-                    class_weight = 'balanced'
+                    class_weight = 'auto'
+                elif self.class_weight == "balanced_subsample":
+                    class_weight = "balanced"
                 else:
                     class_weight = self.class_weight
-                expanded_class_weight = compute_sample_weight(class_weight,
-                                                              y_original)
+                with warnings.catch_warnings():
+                    if class_weight == "auto":
+                        warnings.simplefilter('ignore', DeprecationWarning)
+                    expanded_class_weight = compute_sample_weight(class_weight,
+                                                                  y_original)
 
         return y, expanded_class_weight
 
@@ -758,7 +774,7 @@ class RandomForestClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.
 
-    class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
+    class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional
 
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one. For
@@ -769,7 +785,7 @@ class RandomForestClassifier(ForestClassifier):
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``
 
-       The "subsample" mode is the same as "balanced" except that weights are
+       The "balanced_subsample" mode is the same as "balanced" except that weights are
        computed based on the bootstrap sample for every tree grown.
 
        For multi-output, the weights of each column of y will be multiplied.
@@ -1101,7 +1117,7 @@ class ExtraTreesClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.
 
-    class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
+    class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional
 
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one. For
@@ -1112,7 +1128,7 @@ class ExtraTreesClassifier(ForestClassifier):
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``
 
-       The "subsample" mode is the same as "balanced" except that weights are
+       The "balanced_subsample" mode is the same as "balanced" except that weights are
        computed based on the bootstrap sample for every tree grown.
 
        For multi-output, the weights of each column of y will be multiplied.
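The core of the "balanced_subsample" behavior above is the per-tree call to compute_sample_weight('balanced', y, indices). A standalone sketch of what that computes; the toy labels and indices are hypothetical, and the import path is the one in current scikit-learn releases.

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0] * 9 + [1])  # imbalanced toy labels
indices = np.random.RandomState(0).randint(0, 10, size=10)  # stand-in for one tree's bootstrap indices

# Weights for every sample in y, but with class frequencies counted on y[indices],
# i.e. "balanced" recomputed per bootstrap sample, as in _parallel_build_trees.
per_tree = compute_sample_weight('balanced', y, indices=indices)

# Plain "balanced" uses the full y once, before any resampling.
overall = compute_sample_weight('balanced', y)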

sklearn/ensemble/tests/test_forest.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 from sklearn.utils.testing import assert_greater_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.testing import ignore_warnings
 
 from sklearn import datasets
@@ -802,7 +803,11 @@ def check_class_weight_balanced_and_bootstrap_multi_output(name):
     clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
                            random_state=0)
     clf.fit(X, _y)
+    # smoke test for subsample and balanced subsample
+    clf = ForestClassifier(class_weight='balanced_subsample', random_state=0)
+    clf.fit(X, _y)
     clf = ForestClassifier(class_weight='subsample', random_state=0)
+    #assert_warns_message(DeprecationWarning, "balanced_subsample", clf.fit, X, _y)
     clf.fit(X, _y)
 
 
0 commit comments
