@@ -330,23 +330,38 @@ def mean_squared_log_error(y_true, y_pred,
330
330
sample_weight , multioutput )
331
331
332
332
333
- def median_absolute_error (y_true , y_pred ):
333
+ def median_absolute_error (y_true , y_pred , multioutput = 'uniform_average' ):
334
334
"""Median absolute error regression loss
335
335
336
- Read more in the :ref:`User Guide <median_absolute_error>`.
336
+ Median absolute error output is non-negative floating point. The best value
337
+ is 0.0. Read more in the :ref:`User Guide <median_absolute_error>`.
337
338
338
339
Parameters
339
340
----------
340
- y_true : array-like of shape (n_samples, )
341
+ y_true : array-like of shape = (n_samples) or (n_samples, n_outputs )
341
342
Ground truth (correct) target values.
342
343
343
- y_pred : array-like of shape (n_samples, )
344
+ y_pred : array-like of shape = (n_samples) or (n_samples, n_outputs )
344
345
Estimated target values.
345
346
347
+ multioutput : {'raw_values', 'uniform_average'} or array-like of shape
348
+ (n_outputs,)
349
+ Defines aggregating of multiple output values. Array-like value defines
350
+ weights used to average errors.
351
+
352
+ 'raw_values' :
353
+ Returns a full set of errors in case of multioutput input.
354
+
355
+ 'uniform_average' :
356
+ Errors of all outputs are averaged with uniform weight.
357
+
346
358
Returns
347
359
-------
348
- loss : float
349
- A positive floating point value (the best value is 0.0).
360
+ loss : float or ndarray of floats
361
+ If multioutput is 'raw_values', then mean absolute error is returned
362
+ for each output separately.
363
+ If multioutput is 'uniform_average' or an ndarray of weights, then the
364
+ weighted average of all output errors is returned.
350
365
351
366
Examples
352
367
--------
@@ -355,12 +370,27 @@ def median_absolute_error(y_true, y_pred):
355
370
>>> y_pred = [2.5, 0.0, 2, 8]
356
371
>>> median_absolute_error(y_true, y_pred)
357
372
0.5
373
+ >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
374
+ >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
375
+ >>> median_absolute_error(y_true, y_pred)
376
+ 0.75
377
+ >>> median_absolute_error(y_true, y_pred
E30A
, multioutput='raw_values')
378
+ array([0.5, 1. ])
379
+ >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
380
+ 0.85
358
381
359
382
"""
360
- y_type , y_true , y_pred , _ = _check_reg_targets (y_true , y_pred , None )
361
- if y_type == 'continuous-multioutput' :
362
- raise ValueError ("Multioutput not supported in median_absolute_error" )
363
- return np .median (np .abs (y_pred - y_true ))
383
+ y_type , y_true , y_pred , multioutput = _check_reg_targets (
384
+ y_true , y_pred , multioutput )
385
+ output_errors = np .median (np .abs (y_pred - y_true ), axis = 0 )
386
+ if isinstance (multioutput , str ):
387
+ if multioutput == 'raw_values' :
388
+ return output_errors
389
+ elif multioutput == 'uniform_average' :
390
+ # pass None as weights to np.average: uniform mean
391
+ multioutput = None
392
+
393
+ return np .average (output_errors , weights = multioutput )
364
394
365
395
366
396
def explained_variance_score (y_true , y_pred ,
0 commit comments