Merge pull request #2566 from Manoj-Kumar-S/constant-output · r2k0/scikit-learn@bf033bf · GitHub

Commit bf033bf

Merge pull request scikit-learn#2566 from Manoj-Kumar-S/constant-output
Constant output dummy classifier.
2 parents 623dfef + 23331a9 · commit bf033bf

3 files changed: +92 −4 lines

doc/modules/model_evaluation.rst

Lines changed: 4 additions & 1 deletion
@@ -1073,6 +1073,9 @@ implements three such simple strategies for classification:
   set's class distribution,
 - `most_frequent` always predicts the most frequent label in the training set,
 - `uniform` generates predictions uniformly at random.
+- `constant` always predicts a constant label that is provided by the user.
+  A major motivation of this method is F1-scoring when the positive class
+  is in the minority.
 
 Note that with all these strategies, the `predict` method completely ignores
 the input data!

@@ -1096,7 +1099,7 @@ Next, let's compare the accuracy of `SVC` and `most_frequent`::
     0.63...
     >>> clf = DummyClassifier(strategy='most_frequent',random_state=0)
     >>> clf.fit(X_train, y_train)
-    DummyClassifier(random_state=0, strategy='most_frequent')
+    DummyClassifier(constant=None, random_state=0, strategy='most_frequent')
     >>> clf.score(X_test, y_test) # doctest: +ELLIPSIS
     0.57...
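
The documentation above motivates the `constant` strategy with F1-scoring on a
minority positive class. A minimal, hypothetical sketch of that use case (the
imbalanced toy data is illustrative, not part of the diff):

    import numpy as np
    from sklearn.dummy import DummyClassifier
    from sklearn.metrics import f1_score

    # Imbalanced toy problem: the positive class (1) is 10% of the data.
    X = np.zeros((100, 1))  # features are ignored by DummyClassifier
    y = np.array([0] * 90 + [1] * 10)

    # Baseline that always predicts the positive class.
    clf = DummyClassifier(strategy='constant', constant=1)
    clf.fit(X, y)
    print(f1_score(y, clf.predict(X)))  # precision 0.1, recall 1.0 -> F1 ~ 0.18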

sklearn/dummy.py

Lines changed: 43 additions & 3 deletions
@@ -28,10 +28,17 @@ class DummyClassifier(BaseEstimator, ClassifierMixin):
     * "most_frequent": always predicts the most frequent label in the
       training set.
     * "uniform": generates predictions uniformly at random.
+    * "constant": always predicts a constant label that is provided by
+      the user. This is useful for metrics that evaluate a non-majority
+      class.
 
     random_state: int seed, RandomState instance, or None (default)
         The seed of the pseudo random number generator to use.
 
+    constant: int or str or array of shape = [n_outputs]
+        The explicit constant as predicted by the "constant" strategy. This
+        parameter is useful only for the "constant" strategy.
+
     Attributes
     ----------
     `classes_` : array or list of array of shape = [n_classes]

@@ -48,11 +55,14 @@ class DummyClassifier(BaseEstimator, ClassifierMixin):
 
     `outputs_2d_` : bool,
         True if the output at fit is 2d, else false.
+
     """
 
-    def __init__(self, strategy="stratified", random_state=None):
+    def __init__(self, strategy="stratified", random_state=None,
+                 constant=None):
         self.strategy = strategy
         self.random_state = random_state
+        self.constant = constant
 
     def fit(self, X, y):
         """Fit the random classifier.

@@ -71,7 +81,8 @@ def fit(self, X, y):
         self : object
             Returns self.
         """
-        if self.strategy not in ("most_frequent", "stratified", "uniform"):
+        if self.strategy not in ("most_frequent", "stratified", "uniform",
+                                 "constant"):
             raise ValueError("Unknown strategy type.")
 
         y = np.atleast_1d(y)
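
With the new `constant` keyword in `__init__` above, `constant=None` now
appears in the estimator's repr, which is exactly why the doctest output in
model_evaluation.rst changes in the first file. A quick illustration:

    from sklearn.dummy import DummyClassifier

    clf = DummyClassifier(strategy='most_frequent', random_state=0)
    print(clf)
    # DummyClassifier(constant=None, random_state=0, strategy='most_frequent')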
@@ -85,12 +96,29 @@ def fit(self, X, y):
         self.n_classes_ = []
         self.class_prior_ = []
 
+        if self.strategy == "constant":
+            if self.constant is None:
+                raise ValueError("Constant target value has to be specified "
+                                 "when the constant strategy is used.")
+            else:
+                constant = np.reshape(np.atleast_1d(self.constant), (-1, 1))
+                if constant.shape[0] != self.n_outputs_:
+                    raise ValueError("Constant target value should have "
+                                     "shape (%d, 1)." % self.n_outputs_)
+
         for k in xrange(self.n_outputs_):
             classes, y_k = unique(y[:, k], return_inverse=True)
             self.classes_.append(classes)
             self.n_classes_.append(classes.shape[0])
             self.class_prior_.append(np.bincount(y_k) / float(y_k.shape[0]))
 
+            # Checking in case of constant strategy if the constant provided
+            # by the user is in y.
+            if self.strategy == "constant":
+                if constant[k] not in self.classes_[k]:
+                    raise ValueError("The constant target value must be "
+                                     "present in training data")
+
         if self.n_outputs_ == 1 and not self.output_2d_:
             self.n_classes_ = self.n_classes_[0]
             self.classes_ = self.classes_[0]
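
The fit-time validation reshapes whatever the user passed into an
(n_outputs, 1) column, so scalar, string, and list constants are handled
uniformly. A standalone sketch of just that step:

    import numpy as np

    for value in (1, 'one', [1, 0]):
        constant = np.reshape(np.atleast_1d(value), (-1, 1))
        print(constant.shape)
    # (1, 1), (1, 1), (2, 1): fit() then compares constant.shape[0]
    # against self.n_outputs_ and raises ValueError on a mismatch.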
@@ -123,12 +151,13 @@ def predict(self, X):
         n_classes_ = self.n_classes_
         classes_ = self.classes_
         class_prior_ = self.class_prior_
+        constant = self.constant
         if self.n_outputs_ == 1:
             # Get same type even for self.n_outputs_ == 1
             n_classes_ = [n_classes_]
             classes_ = [classes_]
             class_prior_ = [class_prior_]
-
+            constant = [constant]
         # Compute probability only once
         if self.strategy == "stratified":
             proba = self.predict_proba(X)

@@ -146,6 +175,10 @@ def predict(self, X):
             elif self.strategy == "uniform":
                 ret = rs.randint(n_classes_[k], size=n_samples)
 
+            elif self.strategy == "constant":
+                ret = np.ones(n_samples, dtype=int) * (
+                    np.where(classes_[k] == constant[k]))
+
             y.append(classes_[k][ret])
 
         y = np.vstack(y).T
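
The constant branch of `predict` finds the index of the requested label in
`classes_[k]` via `np.where` and fills an index array with it; `classes_[k][ret]`
then maps indices back to labels, mirroring the other strategies. The same
trick on toy arrays (values assumed for illustration):

    import numpy as np

    classes = np.array([1, 2])  # sorted unique labels for output k
    constant = 1                # the user-requested label
    ind = np.where(classes == constant)   # (array([0]),): position of the label
    ret = np.ones(4, dtype=int) * ind     # broadcasts to [[0, 0, 0, 0]]
    print(classes[ret])                   # [[1 1 1 1]]: every sample gets label 1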
@@ -181,11 +214,13 @@ def predict_proba(self, X):
         n_classes_ = self.n_classes_
         classes_ = self.classes_
         class_prior_ = self.class_prior_
+        constant = self.constant
         if self.n_outputs_ == 1 and not self.output_2d_:
             # Get same type even for self.n_outputs_ == 1
             n_classes_ = [n_classes_]
             classes_ = [classes_]
             class_prior_ = [class_prior_]
+            constant = [constant]
 
         P = []
         for k in xrange(self.n_outputs_):

@@ -201,6 +236,11 @@ def predict_proba(self, X):
                 out = np.ones((n_samples, n_classes_[k]), dtype=np.float64)
                 out /= n_classes_[k]
 
+            elif self.strategy == "constant":
+                ind = np.where(classes_[k] == constant[k])
+                out = np.zeros((n_samples, n_classes_[k]), dtype=np.float64)
+                out[:, ind] = 1.0
+
             P.append(out)
 
         if self.n_outputs_ == 1 and not self.output_2d_:
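
For `predict_proba`, the constant strategy yields a degenerate distribution:
all probability mass in the constant label's column. The same logic on toy
values (note: this sketch takes `np.where(...)[0]`, since indexing with the
raw tuple as the committed code does is deprecated in newer NumPy):

    import numpy as np

    classes = np.array([1, 2])
    constant = 1
    n_samples = 4

    ind = np.where(classes == constant)[0]  # column of the constant label
    out = np.zeros((n_samples, len(classes)), dtype=np.float64)
    out[:, ind] = 1.0
    print(out)
    # [[1. 0.]
    #  [1. 0.]
    #  [1. 0.]
    #  [1. 0.]]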

sklearn/tests/test_dummy.py

Lines changed: 45 additions & 0 deletions
@@ -208,3 +208,48 @@ def test_multioutput_regressor():
 def test_regressor_exceptions():
     reg = DummyRegressor()
     assert_raises(ValueError, reg.predict, [])
+
+
+def test_constant_strategy():
+    X = [[0], [0], [0], [0]]  # ignored
+    y = [2, 1, 2, 2]
+
+    clf = DummyClassifier(strategy="constant", random_state=0, constant=1)
+    clf.fit(X, y)
+    assert_array_equal(clf.predict(X), np.ones(len(X)))
+    _check_predict_proba(clf, X, y)
+
+    X = [[0], [0], [0], [0]]  # ignored
+    y = ['two', 'one', 'two', 'two']
+    clf = DummyClassifier(strategy="constant", random_state=0, constant='one')
+    clf.fit(X, y)
+    assert_array_equal(clf.predict(X), np.array(['one'] * 4))
+    _check_predict_proba(clf, X, y)
+
+
+def test_constant_strategy_multioutput():
+    X = [[0], [0], [0], [0]]  # ignored
+    y = np.array([[2, 3],
+                  [1, 3],
+                  [2, 3],
+                  [2, 0]])
+
+    n_samples = len(X)
+
+    clf = DummyClassifier(strategy="constant", random_state=0,
+                          constant=[1, 0])
+    clf.fit(X, y)
+    assert_array_equal(clf.predict(X),
+                       np.hstack([np.ones((n_samples, 1)),
+                                  np.zeros((n_samples, 1))]))
+    _check_predict_proba(clf, X, y)
+
+
+def test_constant_strategy_exceptions():
+    X = [[0], [0], [0], [0]]  # ignored
+    y = [2, 1, 2, 2]
+    clf = DummyClassifier(strategy="constant", random_state=0)
+    assert_raises(ValueError, clf.fit, X, y)
+    clf = DummyClassifier(strategy="constant", random_state=0,
+                          constant=[2, 0])
+    assert_raises(ValueError, clf.fit, X, y)
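
The two failure modes exercised by `test_constant_strategy_exceptions` can be
reproduced directly; this minimal sketch prints the error messages added in
sklearn/dummy.py above:

    from sklearn.dummy import DummyClassifier

    X, y = [[0], [0], [0], [0]], [2, 1, 2, 2]

    # No constant supplied with the "constant" strategy.
    try:
        DummyClassifier(strategy="constant").fit(X, y)
    except ValueError as e:
        print(e)  # Constant target value has to be specified ...

    # Constant of the wrong shape for single-output y.
    try:
        DummyClassifier(strategy="constant", constant=[2, 0]).fit(X, y)
    except ValueError as e:
        print(e)  # Constant target value should have shape (1, 1).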
