Merge pull request #4190 from trevorstephens/refactor_cw · scikit-learn/scikit-learn@420daac

Commit 420daac

Merge pull request #4190 from trevorstephens/refactor_cw
MRG+2: Refactor - Farm out class_weight calcs to .utils
2 parents 1d08f08 + f2431db commit 420daac

File tree

6 files changed, +223 -89 lines

sklearn/ensemble/forest.py

Lines changed: 8 additions & 50 deletions
@@ -56,7 +56,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
 from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                     ExtraTreeClassifier, ExtraTreeRegressor)
 from ..tree._tree import DTYPE, DOUBLE
-from ..utils import check_random_state, check_array, compute_class_weight
+from ..utils import check_random_state, check_array, compute_sample_weight
 from ..utils.validation import DataConversionWarning, check_is_fitted
 from .base import BaseEnsemble, _partition_estimators
 from ..utils.fixes import bincount

@@ -89,30 +89,7 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
         curr_sample_weight *= sample_counts

         if class_weight == 'subsample':
-
-            expanded_class_weight = [curr_sample_weight]
-
-            for k in range(y.shape[1]):
-                y_full = y[:, k]
-                classes_full = np.unique(y_full)
-                y_boot = y[indices, k]
-                classes_boot = np.unique(y_boot)
-
-                # Get class weights for the bootstrap sample, covering all
-                # classes in case some were missing from the bootstrap sample
-                weight_k = np.choose(
-                    np.searchsorted(classes_boot, classes_full),
-                    compute_class_weight('auto', classes_boot, y_boot),
-                    mode='clip')
-
-                # Expand weights over the original y for this output
-                weight_k = weight_k[np.searchsorted(classes_full, y_full)]
-                expanded_class_weight.append(weight_k)
-
-            # Multiply all weights by sample & bootstrap weights
-            curr_sample_weight = np.prod(expanded_class_weight,
-                                         axis=0,
-                                         dtype=np.float64)
+            curr_sample_weight *= compute_sample_weight('auto', y, indices)

         tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)

@@ -449,33 +426,14 @@ def _validate_y_class_weight(self, y):
                          'properly estimate the class frequency '
                          'distributions. Pass the resulting weights as the '
                          'class_weight parameter.')
-            elif self.n_outputs_ > 1:
-                if not hasattr(self.class_weight, "__iter__"):
-                    raise ValueError("For multi-output, class_weight should "
-                                     "be a list of dicts, or a valid string.")
-                elif len(self.class_weight) != self.n_outputs_:
-                    raise ValueError("For multi-output, number of elements "
-                                     "in class_weight should match number of "
-                                     "outputs.")

             if self.class_weight != 'subsample' or not self.bootstrap:
-                expanded_class_weight = []
-                for k in range(self.n_outputs_):
-                    if self.class_weight in valid_presets:
-                        class_weight_k = 'auto'
-                    elif self.n_outputs_ == 1:
-                        class_weight_k = self.class_weight
-                    else:
-                        class_weight_k = self.class_weight[k]
-                    weight_k = compute_class_weight(class_weight_k,
-                                                    self.classes_[k],
-                                                    y_original[:, k])
-                    weight_k = weight_k[np.searchsorted(self.classes_[k],
-                                                        y_original[:, k])]
-                    expanded_class_weight.append(weight_k)
-                expanded_class_weight = np.prod(expanded_class_weight,
-                                                axis=0,
-                                                dtype=np.float64)
+                if self.class_weight == 'subsample':
+                    class_weight = 'auto'
+                else:
+                    class_weight = self.class_weight
+                expanded_class_weight = compute_sample_weight(class_weight,
+                                                              y_original)

         return y, expanded_class_weight

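For context, the one-liner above replaces the inlined bootstrap logic with the new helper. A minimal sketch (not part of the commit; toy labels made up for illustration) of what it computes: with the 'auto' preset of this era and a bootstrap indices array, weights are estimated from the drawn sample, expanded back over the original y, and classes absent from the draw get weight zero.

    import numpy as np
    from sklearn.utils import compute_sample_weight

    # Toy single-output target and a bootstrap draw with repeated indices.
    y = np.array([0, 0, 0, 1, 1, 2])
    indices = np.array([0, 1, 1, 3, 4, 4])  # class 2 is never drawn

    # Weights come from the bootstrap sample ('auto' is the only preset
    # allowed with indices) and are expanded over the full y; samples of
    # the undrawn class 2 receive weight 0.
    curr_sample_weight = np.ones(len(y))
    curr_sample_weight *= compute_sample_weight('auto', y, indices)
    print(curr_sample_weight)
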
sklearn/linear_model/ridge.py

Lines changed: 4 additions & 7 deletions
@@ -21,7 +21,7 @@
 from ..base import RegressorMixin
 from ..utils.extmath import safe_sparse_dot
 from ..utils import check_X_y
-from ..utils import compute_class_weight
+from ..utils import compute_sample_weight, compute_class_weight
 from ..utils import column_or_1d
 from ..preprocessing import LabelBinarizer
 from ..grid_search import GridSearchCV

@@ -597,10 +597,8 @@ def fit(self, X, y):
         y = column_or_1d(y, warn=True)

         if self.class_weight:
-            cw = compute_class_weight(self.class_weight,
-                                      self.classes_, y)
             # get the class weight corresponding to each sample
-            sample_weight = cw[np.searchsorted(self.classes_, y)]
+            sample_weight = compute_sample_weight(self.class_weight, y)
         else:
             sample_weight = None

@@ -1074,10 +1072,9 @@ def fit(self, X, y, sample_weight=None):
         Y = self._label_binarizer.fit_transform(y)
         if not self._label_binarizer.y_type_.startswith('multilabel'):
             y = column_or_1d(y, warn=True)
-            cw = compute_class_weight(self.class_weight,
-                                      self.classes_, Y)
             # modify the sample weights with the corresponding class weight
-            sample_weight *= cw[np.searchsorted(self.classes_, y)]
+            sample_weight = (sample_weight *
+                             compute_sample_weight(self.class_weight, y))
         _BaseRidgeCV.fit(self, X, Y, sample_weight=sample_weight)
         return self

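Both ridge fixes collapse the former compute_class_weight-plus-searchsorted pattern into a single call. A quick equivalence sketch (illustrative values, using the positional compute_class_weight signature as it stood at this commit):

    import numpy as np
    from sklearn.utils import compute_class_weight, compute_sample_weight

    y = np.array([0, 1, 1, 1])
    classes = np.unique(y)
    class_weight = {0: 2.0, 1: 0.5}

    # Old pattern (removed above): expand per-class weights per sample.
    cw = compute_class_weight(class_weight, classes, y)
    old = cw[np.searchsorted(classes, y)]

    # New pattern: one call performs both steps.
    new = compute_sample_weight(class_weight, y)

    assert np.array_equal(old, new)  # both give [2.0, 0.5, 0.5, 0.5]
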
sklearn/tree/tree.py

Lines changed: 3 additions & 30 deletions
@@ -25,7 +25,7 @@
 from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
 from ..externals import six
 from ..feature_selection.from_model import _LearntSelectorMixin
-from ..utils import check_array, check_random_state, compute_class_weight
+from ..utils import check_array, check_random_state, compute_sample_weight
 from ..utils.validation import NotFittedError, check_is_fitted


@@ -172,35 +172,8 @@ def fit(self, X, y, sample_weight=None, check_input=True):
                 self.n_classes_.append(classes_k.shape[0])

             if self.class_weight is not None:
-                if isinstance(self.class_weight, six.string_types):
-                    if self.class_weight != "auto":
-                        raise ValueError('The only supported preset for '
-                                         'class_weight is "auto". Given "%s".'
-                                         % self.class_weight)
-                elif self.n_outputs_ > 1:
-                    if not hasattr(self.class_weight, "__iter__"):
-                        raise ValueError('For multi-output, class_weight '
-                                         'should be a list of dicts, or '
-                                         '"auto".')
-                    elif len(self.class_weight) != self.n_outputs_:
-                        raise ValueError("For multi-output, number of "
-                                         "elements in class_weight should "
-                                         "match number of outputs.")
-                expanded_class_weight = []
-                for k in range(self.n_outputs_):
-                    if self.n_outputs_ == 1 or self.class_weight == 'auto':
-                        class_weight_k = self.class_weight
-                    else:
-                        class_weight_k = self.class_weight[k]
-                    weight_k = compute_class_weight(class_weight_k,
-                                                    self.classes_[k],
-                                                    y_original[:, k])
-                    weight_k = weight_k[np.searchsorted(self.classes_[k],
-                                                        y_original[:, k])]
-                    expanded_class_weight.append(weight_k)
-                expanded_class_weight = np.prod(expanded_class_weight,
-                                                axis=0,
-                                                dtype=np.float64)
+                expanded_class_weight = compute_sample_weight(
+                    self.class_weight, y_original)

         else:
             self.classes_ = [None] * self.n_outputs_

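The tree change also hands the multi-output case to the helper, where per-column weights are multiplied (see the docstring in sklearn/utils/class_weight.py below). A made-up two-output illustration:

    import numpy as np
    from sklearn.utils import compute_sample_weight

    # Two-output target; class_weight is a list of dicts, one per column.
    y = np.array([[0, 0],
                  [1, 0],
                  [1, 1]])
    w = compute_sample_weight([{0: 1.0, 1: 2.0}, {0: 3.0, 1: 0.5}], y)
    print(w)  # column weights multiply: [1*3, 2*3, 2*0.5] = [3., 6., 1.]
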
sklearn/utils/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -13,15 +13,15 @@
                          check_random_state, column_or_1d, check_array,
                          check_consistent_length, check_X_y, indexable,
                          check_symmetric)
-from .class_weight import compute_class_weight
+from .class_weight import compute_class_weight, compute_sample_weight
 from ..externals.joblib import cpu_count


 __all__ = ["murmurhash3_32", "as_float_array",
            "assert_all_finite", "check_array",
            "warn_if_not_float",
            "check_random_state",
-           "compute_class_weight",
+           "compute_class_weight", "compute_sample_weight",
            "column_or_1d", "safe_indexing",
            "check_consistent_length", "check_X_y", 'indexable']

sklearn/utils/class_weight.py

Lines changed: 103 additions & 0 deletions
@@ -3,6 +3,8 @@
 # License: BSD 3 clause

 import numpy as np
+from ..externals import six
+from ..utils.fixes import in1d

 from .fixes import bincount

@@ -61,3 +63,104 @@ def compute_class_weight(class_weight, classes, y):
             weight[i] = class_weight[c]

     return weight
+
+
+def compute_sample_weight(class_weight, y, indices=None):
+    """Estimate sample weights by class for unbalanced datasets.
+
+    Parameters
+    ----------
+    class_weight : dict, list of dicts, "auto", or None, optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one. For
+        multi-output problems, a list of dicts can be provided in the same
+        order as the columns of y.
+
+        The "auto" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data.
+
+        For multi-output, the weights of each column of y will be multiplied.
+
+    y : array-like, shape = [n_samples] or [n_samples, n_outputs]
+        Array of original class labels per sample.
+
+    indices : array-like, shape (n_subsample,), or None
+        Array of indices to be used in a subsample. Can be of length less than
+        n_samples in the case of a subsample, or equal to n_samples in the
+        case of a bootstrap subsample with repeated indices. If None, the
+        sample weight will be calculated over the full sample. Only "auto" is
+        supported for class_weight if this is provided.
+
+    Returns
+    -------
+    sample_weight_vect : ndarray, shape (n_samples,)
+        Array with sample weights as applied to the original y
+    """
+
+    y = np.atleast_1d(y)
+    if y.ndim == 1:
+        y = np.reshape(y, (-1, 1))
+    n_outputs = y.shape[1]
+
+    if isinstance(class_weight, six.string_types):
+        if class_weight != 'auto':
+            raise ValueError('The only valid preset for class_weight is '
+                             '"auto". Given "%s".' % class_weight)
+    elif (indices is not None and
+          not isinstance(class_weight, six.string_types)):
+        raise ValueError('The only valid class_weight for subsampling is '
+                         '"auto". Given "%s".' % class_weight)
+    elif n_outputs > 1:
+        if (not hasattr(class_weight, "__iter__") or
+                isinstance(class_weight, dict)):
+            raise ValueError("For multi-output, class_weight should be a "
+                             "list of dicts, or a valid string.")
+        if len(class_weight) != n_outputs:
+            raise ValueError("For multi-output, number of elements in "
+                             "class_weight should match number of outputs.")
+
+    expanded_class_weight = []
+    for k in range(n_outputs):
+
+        y_full = y[:, k]
+        classes_full = np.unique(y_full)
+        classes_missing = None
+
+        if class_weight == 'auto' or n_outputs == 1:
+            class_weight_k = class_weight
+        else:
+            class_weight_k = class_weight[k]
+
+        if indices is not None:
+            # Get class weights for the subsample, covering all classes in
+            # case some labels that were present in the original data are
+            # missing from the sample.
+            y_subsample = y[indices, k]
+            classes_subsample = np.unique(y_subsample)
+
+            weight_k = np.choose(np.searchsorted(classes_subsample,
+                                                 classes_full),
+                                 compute_class_weight(class_weight_k,
+                                                      classes_subsample,
+                                                      y_subsample),
+                                 mode='clip')
+
+            classes_missing = set(classes_full) - set(classes_subsample)
+        else:
+            weight_k = compute_class_weight(class_weight_k,
+                                            classes_full,
+                                            y_full)
+
+        weight_k = weight_k[np.searchsorted(classes_full, y_full)]
+
+        if classes_missing:
+            # Make missing classes' weight zero
+            weight_k[in1d(y_full, list(classes_missing))] = 0.
+
+        expanded_class_weight.append(weight_k)
+
+    expanded_class_weight = np.prod(expanded_class_weight,
+                                    axis=0,
+                                    dtype=np.float64)
+
+    return expanded_class_weight

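Finally, a short usage sketch of the new helper (illustrative values; behaviour as documented in the docstring above):

    import numpy as np
    from sklearn.utils import compute_sample_weight

    y = np.array([1, 1, 1, 2])

    # 'auto' adjusts weights inversely proportional to class frequency,
    # so the rare class 2 is up-weighted relative to class 1.
    print(compute_sample_weight('auto', y))

    # None yields uniform weights of one.
    print(compute_sample_weight(None, y))

    # With indices, only the 'auto' preset is valid; anything else hits
    # the ValueError raised in the validation block above.
    try:
        compute_sample_weight({1: 2.0, 2: 1.0}, y, indices=[0, 1])
    except ValueError as exc:
        print(exc)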