@@ -59,6 +59,27 @@ def _check_behavior_2d(clf):
59
59
assert_equal (y .shape , y_pred .shape )
60
60
61
61
62
def _check_behavior_2d_for_constant(clf):
    # Multi-output (2-d target) case only: fit a fresh clone of ``clf`` on a
    # target with four samples and five output columns, then check that the
    # predictions come back with exactly the same shape as the target.
    features = np.array([[0], [0], [0], [0]])  # ignored
    targets = np.array([
        [1, 0, 5, 4, 3],
        [2, 0, 1, 2, 5],
        [1, 0, 4, 5, 2],
        [1, 3, 3, 2, 0],
    ])

    # Work on a clone so the estimator handed in by the caller is not refit.
    fitted = clone(clf)
    fitted.fit(features, targets)
    predictions = fitted.predict(features)

    assert_equal(targets.shape, predictions.shape)
73
+
74
+
75
def _check_equality_regressor(statistic, y_learn, y_pred_learn,
                              y_test, y_pred_test):
    # Every prediction row must equal ``statistic``: build the expected
    # matrix by repeating ``statistic`` once per sample of the matching
    # target and compare it element-wise, first for the learning set and
    # then for the test set.
    checks = ((y_learn, y_pred_learn), (y_test, y_pred_test))
    for targets, predictions in checks:
        n_samples = targets.shape[0]
        expected = np.tile(statistic, (n_samples, 1))
        assert_array_equal(expected, predictions)
81
+
82
+
62
83
def test_most_frequent_strategy ():
63
84
X = [[0 ], [0 ], [0 ], [0 ]] # ignored
64
85
y = [1 , 2 , 1 , 1 ]
@@ -175,33 +196,37 @@ def test_classifier_exceptions():
175
196
assert_raises (ValueError , clf .predict_proba , [])
176
197
177
198
178
def test_mean_strategy_regressor():
    # A default-constructed DummyRegressor must predict the mean of the
    # training target for every sample.
    rng = np.random.RandomState(seed=1)

    X = [[0]] * 4  # ignored
    y = rng.randn(4)

    reg = DummyRegressor()
    reg.fit(X, y)

    predictions = reg.predict(X)
    expected = [np.mean(y)] * len(X)
    assert_array_equal(predictions, expected)
185
209
186
210
187
def test_mean_strategy_multioutput_regressor():
    # With a 2-d target, a default DummyRegressor must predict the
    # per-column mean of the learning target on both learn and test data.
    rng = np.random.RandomState(seed=1)

    # Draw order matters for reproducibility: learn set first, then test set.
    X_learn = rng.randn(10, 10)
    y_learn = rng.randn(10, 5)
    X_test = rng.randn(20, 10)
    y_test = rng.randn(20, 5)

    # Expected statistic as a single row, one entry per output column.
    mean = np.mean(y_learn, axis=0).reshape((1, -1))

    # Correctness oracle
    est = DummyRegressor()
    est.fit(X_learn, y_learn)
    y_pred_learn = est.predict(X_learn)
    y_pred_test = est.predict(X_test)

    _check_equality_regressor(
        mean, y_learn, y_pred_learn, y_test, y_pred_test)
    _check_behavior_2d(est)
206
231
207
232
@@ -210,6 +235,115 @@ def test_regressor_exceptions():
210
235
assert_raises (ValueError , reg .predict , [])
211
236
212
237
238
def test_median_strategy_regressor():
    # strategy="median" must predict the median of the training target for
    # every sample.
    rng = np.random.RandomState(seed=1)

    X = [[0]] * 5  # ignored
    y = rng.randn(5)

    reg = DummyRegressor(strategy="median")
    reg.fit(X, y)

    predictions = reg.predict(X)
    expected = [np.median(y)] * len(X)
    assert_array_equal(predictions, expected)
248
+
249
+
250
def test_median_strategy_multioutput_regressor():
    # With a 2-d target, strategy="median" must predict the per-column
    # median of the learning target on both learn and test data.
    rng = np.random.RandomState(seed=1)

    # Draw order matters for reproducibility: learn set first, then test set.
    X_learn = rng.randn(10, 10)
    y_learn = rng.randn(10, 5)
    X_test = rng.randn(20, 10)
    y_test = rng.randn(20, 5)

    # Expected statistic as a single row, one entry per output column.
    median = np.median(y_learn, axis=0).reshape((1, -1))

    # Correctness oracle
    est = DummyRegressor(strategy="median")
    est.fit(X_learn, y_learn)
    y_pred_learn = est.predict(X_learn)
    y_pred_test = est.predict(X_test)

    _check_equality_regressor(
        median, y_learn, y_pred_learn, y_test, y_pred_test)
    _check_behavior_2d(est)
271
+
272
+
273
def test_constant_strategy_regressor():
    # strategy="constant" must predict the user-supplied constant for every
    # sample, whether it is given as a one-element list or as a bare scalar.
    rng = np.random.RandomState(seed=1)

    X = [[0]] * 5  # ignored
    y = rng.randn(5)

    for constant in ([43], 43):
        reg = DummyRegressor(strategy="constant", constant=constant)
        reg.fit(X, y)
        assert_array_equal(reg.predict(X), [43] * len(X))
287
+
288
+
289
def test_constant_strategy_multioutput_regressor():
    # With a 2-d target, strategy="constant" must predict the supplied
    # constants (one per output column) on both learn and test data.
    rng = np.random.RandomState(seed=1)

    # Draw order matters for reproducibility: learn set, constants, test set.
    X_learn = rng.randn(10, 10)
    y_learn = rng.randn(10, 5)

    # One constant for each of the five output columns of the target.
    constants = rng.randn(5)

    X_test = rng.randn(20, 10)
    y_test = rng.randn(20, 5)

    # Correctness oracle
    est = DummyRegressor(strategy="constant", constant=constants)
    est.fit(X_learn, y_learn)
    y_pred_learn = est.predict(X_learn)
    y_pred_test = est.predict(X_test)

    _check_equality_regressor(
        constants, y_learn, y_pred_learn, y_test, y_pred_test)
    _check_behavior_2d_for_constant(est)
311
+
312
+
313
def test_y_mean_attribute_regressor():
    # After fitting with strategy='mean', the ``y_mean_`` attribute must
    # equal the mean of the training target.
    X = [[0]] * 5
    y = [1, 2, 4, 6, 8]

    est = DummyRegressor(strategy='mean')
    est.fit(X, y)

    expected_mean = np.mean(y)
    assert_equal(est.y_mean_, expected_mean)
320
+
321
+
322
def test_unknown_strategey_regressor():
    # Fitting with an unrecognised strategy name must raise ValueError.
    # NOTE(review): "strategey" in the function name is a typo; it is kept
    # unchanged here so the test id stays stable.
    X = [[0]] * 5
    y = [1, 2, 4, 6, 8]

    bad_strategy_est = DummyRegressor(strategy='gona')
    assert_raises(ValueError, bad_strategy_est.fit, X, y)
328
+
329
+
330
def test_constants_not_specified_regressor():
    # strategy='constant' without a ``constant`` argument must make ``fit``
    # raise TypeError.
    X = [[0]] * 5
    y = [1, 2, 4, 6, 8]

    missing_constant_est = DummyRegressor(strategy='constant')
    assert_raises(TypeError, missing_constant_est.fit, X, y)
336
+
337
+
338
def test_constant_size_multioutput_regressor():
    # Supplying four constants for a five-column target must make ``fit``
    # raise ValueError.
    rng = np.random.RandomState(seed=1)
    X = rng.randn(10, 10)
    y = rng.randn(10, 5)

    wrong_size_est = DummyRegressor(strategy='constant',
                                    constant=[1, 2, 3, 4])
    assert_raises(ValueError, wrong_size_est.fit, X, y)
345
+
346
+
213
347
def test_constant_strategy ():
214
348
X = [[0 ], [0 ], [0 ], [0 ]] # ignored
215
349
y = [2 , 1 , 2 , 2 ]
0 commit comments