rename cw option to subsample & refactor its implementation · scikit-learn/scikit-learn@35c2535 · GitHub
[go: up one dir, main page]

Skip to content

Commit 35c2535

Browse files
rename cw option to subsample & refactor its implementation
1 parent b541191 commit 35c2535

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

sklearn/ensemble/forest.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,24 +87,27 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
8787
sample_counts = np.bincount(indices, minlength=n_samples)
8888
curr_sample_weight *= sample_counts
8989

90-
if class_weight == 'bootstrap':
90+
if class_weight == 'subsample':
91+
9192
expanded_class_weight = [curr_sample_weight]
93+
9294
for k in range(y.shape[1]):
9395
y_full = y[:, k]
9496
classes_full = np.unique(y_full)
95-
y_boot = y_full[indices]
97+
y_boot = y[indices, k]
9698
classes_boot = np.unique(y_boot)
97-
# Get class weights for the bootstrap sample
98-
weight_k = compute_class_weight('auto', classes_boot, y_boot)
99-
# Expand class weights to cover all classes in original y
100-
# (in case some were missing from the bootstrap sample)
101-
weight_k = np.array([weight_k[np.where(classes_boot == c)][0]
102-
if c in classes_boot
103-
else 0.
104-
for c in classes_full])
99+
100+
# Get class weights for the bootstrap sample, covering all
101+
# classes in case some were missing from the bootstrap sample
102+
weight_k = np.choose(
103+
np.searchsorted(classes_boot, classes_full),
104+
compute_class_weight('auto', classes_boot, y_boot),
105+
mode='clip')
106+
105107
# Expand weights over the original y for this output
106108
weight_k = weight_k[np.searchsorted(classes_full, y_full)]
107109
expanded_class_weight.append(weight_k)
110+
108111
# Multiply all weights by sample & bootstrap weights
109112
curr_sample_weight = np.prod(expanded_class_weight,
110113
axis=0,
@@ -243,7 +246,7 @@ def fit(self, X, y, sample_weight=None):
243246

244247
if expanded_class_weight is not None:
245248
if sample_weight is not None:
246-
sample_weight = np.copy(sample_weight) * expanded_class_weight
249+
sample_weight = sample_weight * expanded_class_weight
247250
else:
248251
sample_weight = expanded_class_weight
249252

@@ -428,14 +431,14 @@ def _validate_y_class_weight(self, y):
428431
self.n_classes_.append(classes_k.shape[0])
429432

430433
if self.class_weight is not None:
431-
valid_presets = ['auto', 'bootstrap']
434+
valid_presets = ('auto', 'subsample')
432435
if isinstance(self.class_weight, six.string_types):
433436
if self.class_weight not in valid_presets:
434437
raise ValueError('Valid presets for class_weight include '
435-
'"auto" and "bootstrap". Given "%s".'
438+
'"auto" and "subsample". Given "%s".'
436439
% self.class_weight)
437440
if self.warm_start:
438-
warn('class_weight presets "auto" or "bootstrap" are '
441+
warn('class_weight presets "auto" or "subsample" are '
439442
'not recommended for warm_start if the fitted data '
440443
'differs from the full dataset. In order to use '
441444
'"auto" weights, use compute_class_weight("auto", '
@@ -453,7 +456,7 @@ def _validate_y_class_weight(self, y):
453456
"in class_weight should match number of "
454457
"outputs.")
455458

456-
if self.class_weight != 'bootstrap' or not self.bootstrap:
459+
if self.class_weight != 'subsample' or not self.bootstrap:
457460
expanded_class_weight = []
458461
for k in range(self.n_outputs_):
459462
if self.class_weight in valid_presets:
@@ -797,7 +800,7 @@ class RandomForestClassifier(ForestClassifier):
797800
and add more estimators to the ensemble, otherwise, just fit a whole
798801
new forest.
799802
800-
class_weight : dict, list of dicts, "auto", "bootstrap" or None, optional
803+
class_weight : dict, list of dicts, "auto", "subsample" or None, optional
801804
802805
Weights associated with classes in the form ``{class_label: weight}``.
803806
If not given, all classes are supposed to have weight one. For
@@ -807,7 +810,7 @@ class RandomForestClassifier(ForestClassifier):
807810
The "auto" mode uses the values of y to automatically adjust
808811
weights inversely proportional to class frequencies in the input data.
809812
810-
The "bootstrap" mode is the same as "auto" except that weights are
813+
The "subsample" mode is the same as "auto" except that weights are
811814
computed based on the bootstrap sample for every tree grown.
812815
813816
For multi-output, the weights of each column of y will be multiplied.
@@ -1127,7 +1130,7 @@ class ExtraTreesClassifier(ForestClassifier):
11271130
and add more estimators to the ensemble, otherwise, just fit a whole
11281131
new forest.
11291132
1130-
class_weight : dict, list of dicts, "auto", "bootstrap" or None, optional
1133+
class_weight : dict, list of dicts, "auto", "subsample" or None, optional
11311134
11321135
Weights associated with classes in the form ``{class_label: weight}``.
11331136
If not given, all classes are supposed to have weight one. For
@@ -1137,7 +1140,7 @@ class ExtraTreesClassifier(ForestClassifier):
11371140
The "auto" mode uses the values of y to automatically adjust
11381141
weights inversely proportional to class frequencies in the input data.
11391142
1140-
The "bootstrap" mode is the same as "auto" except that weights are
1143+
The "subsample" mode is the same as "auto" except that weights are
11411144
computed based on the bootstrap sample for every tree grown.
11421145
11431146
For multi-output, the weights of each column of y will be multiplied.

sklearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ def check_class_weight_auto_and_bootstrap_multi_output(name):
802802
clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
803803
random_state=0)
804804
clf.fit(X, _y)
805-
clf = ForestClassifier(class_weight='bootstrap', random_state=0)
805+
clf = ForestClassifier(class_weight='subsample', random_state=0)
806806
clf.fit(X, _y)
807807

808808

sklearn/tree/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
279279

280280
if expanded_class_weight is not None:
281281
if sample_weight is not None:
282-
sample_weight = np.copy(sample_weight) * expanded_class_weight
282+
sample_weight = sample_weight * expanded_class_weight
283283
else:
284284
sample_weight = expanded_class_weight
285285

0 commit comments

Comments
 (0)
0