8000 Merge pull request #2886 from maheshakya/dummy_regressor · rmurcek/scikit-learn@302a3c0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 302a3c0

Browse files
committed
Merge pull request scikit-learn#2886 from maheshakya/dummy_regressor
[MRG+1] ENH Implemented median and constant strategies in DummyRegressor
2 parents 6091d7d + 3412f29 commit 302a3c0

File tree

3 files changed

+214
-20
lines changed

3 files changed

+214
-20
lines changed

doc/modules/model_evaluation.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,5 +1146,11 @@ classification, it probably means that something went wrong: features are not
11461146
helpful, a hyper parameter is not correctly tuned, the classifier is suffering
11471147
from class imbalance, etc...
11481148

1149-
:class:`DummyRegressor` implements a simple rule of thumb for regression:
1150-
always predict the mean of the training targets.
1149+
:class:`DummyRegressor` also implements three simple rules of thumb for regression:
1150+
1151+
- `mean` always predicts the mean of the training targets.
1152+
- `median` always predicts the median of the training targests.
1153+
- `constant` always predicts a constant value that is provided by the user.
1154+
1155+
In all these strategies, the `predict` method completely ignores
1156+
the input data.

sklearn/dummy.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Author: Mathieu Blondel <mathieu@mblondel.org>
22
# Arnaud Joly <a.joly@ulg.ac.be>
3+
# Maheshakya Wijewardena<maheshakya.10@cse.mrt.ac.lk>
34
# License: BSD 3 clause
45

56
import numpy as np
@@ -8,6 +9,7 @@
89
from .externals.six.moves import xrange
910
from .utils import check_random_state
1011
from .utils.validation import safe_asarray
12+
from sklearn.utils import deprecated
1113

1214

1315
class DummyClassifier(BaseEstimator, ClassifierMixin):
@@ -272,16 +274,30 @@ def predict_log_proba(self, X):
272274

273275
class DummyRegressor(BaseEstimator, RegressorMixin):
274276
"""
275-
DummyRegressor is a regressor that always predicts the mean of the training
276-
targets.
277+
DummyRegressor is a regressor that makes predictions using
278+
simple rules.
277279
278280
This regressor is useful as a simple baseline to compare with other
279281
(real) regressors. Do not use it for real problems.
280282
283+
Parameters
284+
----------
285+
strategy: str
286+
Strategy to use to generate predictions.
287+
* "mean": always predicts the mean of the training set
288+
* "median": always predicts the median of the training set
289+
* "constant": always predicts a constant value that is provided by
290+
the user.
291+
292+
constant: int or float or array of shape = [n_outputs]
293+
The explicit constant as predicted by the "constant" strategy. This
294+
parameter is useful only for the "constant" strategy.
295+
281296
Attributes
282297
----------
283-
`y_mean_` : float or array of shape [n_outputs]
284-
Mean of the training targets.
298+
`constant_' : float or array of shape [n_outputs]
299+
Mean or median of the training targets or constant value given the by
300+
the user.
285301
286302
`n_outputs_` : int,
287303
Number of outputs.
@@ -290,6 +306,17 @@ class DummyRegressor(BaseEstimator, RegressorMixin):
290306
True if the output at fit is 2d, else false.
291307
"""
292308

309+
def __init__(self, strategy="mean", constant=None):
310+
self.strategy = strategy
311+
self.constant = constant
312+
313+
@property
314+
@deprecated('This will be removed in version 0.17')
315+
def y_mean_(self):
316+
if self.strategy == 'mean':
317+
return self.constant_
318+
raise AttributeError
319+
293320
def fit(self, X, y):
294321
"""Fit the random regressor.
295322
@@ -307,10 +334,36 @@ def fit(self, X, y):
307334
self : object
308335
Returns self.
309336
"""
337+
338+
if self.strategy not in ("mean", "median", "constant"):
339+
raise ValueError("Unknown strategy type: %s, "
340+
"expected 'mean', 'median' or 'constant'"
341+
% self.strategy)
342+
310343
y = safe_asarray(y)
311-
self.y_mean_ = np.reshape(np.mean(y, axis=0), (1, -1))
312-
self.n_outputs_ = np.size(self.y_mean_) # y.shape[1] is not safe
313344
self.output_2d_ = (y.ndim == 2)
345+
346+
if self.strategy == "mean":
347+
self.constant_ = np.reshape(np.mean(y, axis=0), (1, -1))
348+
349+
elif self.strategy == "median":
350+
self.constant_ = np.reshape(np.median(y, axis=0), (1, -1))
351+
352+
elif self.strategy == "constant":
353+
if self.constant is None:
354+
raise TypeError("Constant target value has to be specified "
355+
"when the constant strategy is used.")
356+
357+
self.constant = safe_asarray(self.constant)
358+
359+
if self.output_2d_ and self.constant.shape[0] != y.shape[1]:
360+
raise ValueError(
361+
"Constant target value should have "
362+
"shape (%d, 1)." % y.shape[1])
363+
364+
self.constant_ = np.reshape(self.constant, (1, -1))
365+
366+
self.n_outputs_ = np.size(self.constant_) # y.shape[1] is not safe
314367
return self
315368

316369
def predict(self, X):
@@ -328,12 +381,13 @@ def predict(self, X):
328381
y : array, shape = [n_samples] or [n_samples, n_outputs]
329382
Predicted target values for X.
330383
"""
331-
if not hasattr(self, "y_mean_"):
384+
if not hasattr(self, "constant_"):
332385
raise ValueError("DummyRegressor not fitted.")
333386

334387
X = safe_asarray(X)
335388
n_samples = X.shape[0]
336-
y = np.ones((n_samples, 1)) * self.y_mean_
389+
390+
y = np.ones((n_samples, 1)) * self.constant_
337391

338392
if self.n_outputs_ == 1 and not self.output_2d_:
339393
y = np.ravel(y)

sklearn/tests/test_dummy.py

Lines changed: 144 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ def _check_behavior_2d(clf):
5959
assert_equal(y.shape, y_pred.shape)
6060

6161

62+
def _check_behavior_2d_for_constant(clf):
63+
# 2d case only
64+
X = np.array([[0], [0], [0], [0]]) # ignored
65+
y = np.array([[1, 0, 5, 4, 3],
66+
[2, 0, 1, 2, 5],
67+
[1, 0, 4, 5, 2],
68+
[1, 3, 3, 2, 0]])
69+
est = clone(clf)
70+
est.fit(X, y)
71+
y_pred = est.predict(X)
72+
assert_equal(y.shape, y_pred.shape)
73+
74+
75+
def _check_equality_regressor(statistic, y_learn, y_pred_learn,
76+
y_test, y_pred_test):
77+
assert_array_equal(np.tile(statistic, (y_learn.shape[0], 1)),
78+
y_pred_learn)
79+
assert_array_equal(np.tile(statistic, (y_test.shape[0], 1)),
80+
y_pred_test)
81+
82+
6283
def test_most_frequent_strategy():
6384
X = [[0], [0], [0], [0]] # ignored
6485
y = [1, 2, 1, 1]
@@ -175,33 +196,37 @@ def test_classifier_exceptions():
175196
assert_raises(ValueError, clf.predict_proba, [])
176197

177198

178-
def test_regressor():
199+
def test_mean_strategy_regressor():
200+
201+
random_state = np.random.RandomState(seed=1)
202+
179203
X = [[0]] * 4 # ignored
180-
y = [1, 2, 1, 1]
204+
y = random_state.randn(4)
181205

182206
reg = DummyRegressor()
183207
reg.fit(X, y)
184-
assert_array_equal(reg.predict(X), [5. / 4] * len(X))
208+
assert_array_equal(reg.predict(X), [np.mean(y)] * len(X))
185209

186210

187-
def test_multioutput_regressor():
211+
def test_mean_strategy_multioutput_regressor():
188212

189-
X_learn = np.random.randn(10, 10)
190-
y_learn = np.random.randn(10, 5)
213+
random_state = np.random.RandomState(seed=1)
214+
215+
X_learn = random_state.randn(10, 10)
216+
y_learn = random_state.randn(10, 5)
191217

192218
mean = np.mean(y_learn, axis=0).reshape((1, -1))
193219

194-
X_test = np.random.randn(20, 10)
195-
y_test = np.random.randn(20, 5)
220+
X_test = random_state.randn(20, 10)
221+
y_test = random_state.randn(20, 5)
196222

197223
# Correctness oracle
198224
est = DummyRegressor()
199225
est.fit(X_learn, y_learn)
200226
y_pred_learn = est.predict(X_learn)
201227
y_pred_test = est.predict(X_test)
202228

203-
assert_array_equal(np.tile(mean, (y_learn.shape[0], 1)), y_pred_learn)
204-
assert_array_equal(np.tile(mean, (y_test.shape[0], 1)), y_pred_test)
229+
_check_equality_regressor(mean, y_learn, y_pred_learn, y_test, y_pred_test)
205230
_check_behavior_2d(est)
206231

207232

@@ -210,6 +235,115 @@ def test_regressor_exceptions():
210235
assert_raises(ValueError, reg.predict, [])
211236

212237

238+
def test_median_strategy_regressor():
239+
240+
random_state = np.random.RandomState(seed=1)
241+
242+
X = [[0]] * 5 # ignored
243+
y = random_state.randn(5)
244+
245+
reg = DummyRegressor(strategy="median")
246+
reg.fit(X, y)
247+
assert_array_equal(reg.predict(X), [np.median(y)] * len(X))
248+
249+
250+
def test_median_strategy_multioutput_regressor():
251+
252+
random_state = np.random.RandomState(seed=1)
253+
254+
X_learn = random_state.randn(10, 10)
255+
y_learn = random_state.randn(10, 5)
256+
257+
median = np.median(y_learn, axis=0).reshape((1, -1))
258+
259+
X_test = random_state.randn(20, 10)
260+
y_test = random_state.randn(20, 5)
261+
262+
# Correctness oracle
263+
est = DummyRegressor(strategy="median")
264+
est.fit(X_learn, y_learn)
265+
y_pred_learn = est.predict(X_learn)
266+
y_pred_test = est.predict(X_test)
267+
268+
_check_equality_regressor(
269+
median, y_learn, y_pred_learn, y_test, y_pred_test)
270+
_check_behavior_2d(est)
271+
272+
273+
def test_constant_strategy_regressor():
274+
275+
random_state = np.random.RandomState(seed=1)
276+
277+
X = [[0]] * 5 # ignored
278+
y = random_state.randn(5)
279+
280+
reg = DummyRegressor(strategy="constant", constant=[43])
281+
reg.fit(X, y)
282+
assert_array_equal(reg.predict(X), [43] * len(X))
283+
284+
reg = DummyRegressor(strategy="constant", constant=43)
285+
reg.fit(X, y)
286+
assert_array_equal(reg.predict(X), [43] * len(X))
287+
288+
289+
def test_constant_strategy_multioutput_regressor():
290+
291+
random_state = np.random.RandomState(seed=1)
292+
293+
X_learn = random_state.randn(10, 10)
294+
y_learn = random_state.randn(10, 5)
295+
296+
# test with 2d array
297+
constants = random_state.randn(5)
298+
299+
X_test = random_state.randn(20, 10)
300+
y_test = random_state.randn(20, 5)
301+
302+
# Correctness oracle
303+
est = DummyRegressor(strategy="constant", constant=constants)
304+
est.fit(X_learn, y_learn)
305+
y_pred_learn = est.predict(X_learn)
306+
y_pred_test = est.predict(X_test)
307+
308+
_check_equality_regressor(
309+
constants, y_learn, y_pred_learn, y_test, y_pred_test)
310+
_check_behavior_2d_for_constant(est)
311+
312+
313+
def test_y_mean_attribute_regressor():
314+
X = [[0]] * 5
315+
y = [1, 2, 4, 6, 8]
316+
# when strategy = 'mean'
317+
est = DummyRegressor(strategy='mean')
318+
est.fit(X, y)
319+
assert_equal(est.y_mean_, np.mean(y))
320+
321+
322+
def test_unknown_strategey_regressor():
323+
X = [[0]] * 5
324+
y = [1, 2, 4, 6, 8]
325+
326+
est = DummyRegressor(strategy='gona')
327+
assert_raises(ValueError, est.fit, X, y)
328+
329+
330+
def test_constants_not_specified_regressor():
331+
X = [[0]] * 5
332+
y = [1, 2, 4, 6, 8]
333+
334+
est = DummyRegressor(strategy='constant')
335+
assert_raises(TypeError, est.fit, X, y)
336+
337+
338+
def test_constant_size_multioutput_regressor():
339+
random_state = np.random.RandomState(seed=1)
340+
X = random_state.randn(10, 10)
341+
y = random_state.randn(10, 5)
342+
343+
est = DummyRegressor(strategy='constant', constant=[1, 2, 3, 4])
344+
assert_raises(ValueError, est.fit, X, y)
345+
346+
213347
def test_constant_strategy():
214348
X = [[0], [0], [0], [0]] # ignored
215349
y = [2, 1, 2, 2]

0 commit comments

Comments
 (0)
0