ENH implement multioutput support for median_absolute_error (#14732) · jeremiedbb/scikit-learn@d026e32 · GitHub
[go: up one dir, main page]

Skip to content

Commit d026e32

Browse files
agamemnonc authored and glemaitre committed
ENH implement multioutput support for median_absolute_error (scikit-learn#14732)
1 parent ad0e9a9 commit d026e32

File tree

4 files changed

+50
-14
lines changed

4 files changed

+50
-14
lines changed

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,10 @@ Changelog
460460
:func:`metrics.pairwise.manhattan_distances` in the case of sparse matrices.
461461
:pr:`15049` by `Paolo Toccaceli <ptocca>`.
462462

463+
- |Enhancement| :func:`metrics.median_absolute_error` now supports
464+
``multioutput`` parameter.
465+
:pr:`14732` by :user:`Agamemnon Krasoulis <agamemnonc>`.
466+
463467
:mod:`sklearn.model_selection`
464468
..............................
465469

@@ -663,7 +667,7 @@ Changelog
663667
- |Fix| :func:`utils.check_array` will now correctly detect numeric dtypes in
664668
pandas dataframes, fixing a bug where ``float32`` was upcast to ``float64``
665669
unnecessarily. :pr:`15094` by `Andreas Müller`_.
666-
670+
667671
- |API| The following utils have been deprecated and are now private:
668672
- ``choose_check_classifiers_labels``
669673
- ``enforce_estimator_tags_y``
@@ -719,4 +723,3 @@ These changes mostly affect library developers.
719723
:pr:`13392` by :user:`Rok Mihevc <rok>`.
720724

721725
- |Fix| Added ``check_transformer_data_not_an_array`` to checks where missing
722-

sklearn/metrics/regression.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -330,23 +330,38 @@ def mean_squared_log_error(y_true, y_pred,
330330
sample_weight, multioutput)
331331

332332

333-
def median_absolute_error(y_true, y_pred):
333+
def median_absolute_error(y_true, y_pred, multioutput='uniform_average'):
334334
"""Median absolute error regression loss
335335
336-
Read more in the :ref:`User Guide <median_absolute_error>`.
336+
Median absolute error output is non-negative floating point. The best value
337+
is 0.0. Read more in the :ref:`User Guide <median_absolute_error>`.
337338
338339
Parameters
339340
----------
340-
y_true : array-like of shape (n_samples,)
341+
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
341342
Ground truth (correct) target values.
342343
343-
y_pred : array-like of shape (n_samples,)
344+
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
344345
Estimated target values.
345346
347+
multioutput : {'raw_values', 'uniform_average'} or array-like of shape
348+
(n_outputs,)
349+
Defines aggregating of multiple output values. Array-like value defines
350+
weights used to average errors.
351+
352+
'raw_values' :
353+
Returns a full set of errors in case of multioutput input.
354+
355+
'uniform_average' :
356+
Errors of all outputs are averaged with uniform weight.
357+
346358
Returns
347359
-------
348-
loss : float
349-
A positive floating point value (the best value is 0.0).
360+
loss : float or ndarray of floats
361+
If multioutput is 'raw_values', then median absolute error is returned
362+
for each output separately.
363+
If multioutput is 'uniform_average' or an ndarray of weights, then the
364+
weighted average of all output errors is returned.
350365
351366
Examples
352367
--------
@@ -355,12 +370,27 @@ def median_absolute_error(y_true, y_pred):
355370
>>> y_pred = [2.5, 0.0, 2, 8]
356371
>>> median_absolute_error(y_true, y_pred)
357372
0.5
373+
>>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
374+
>>> y_pred = [[0, 2], [-1, 2], [8, -5]]
375+
>>> median_absolute_error(y_true, y_pred)
376+
0.75
377+
>>> median_absolute_error(y_true, y_pred, multioutput='raw_values')
378+
array([0.5, 1. ])
379+
>>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
380+
0.85
358381
359382
"""
360-
y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None)
361-
if y_type == 'continuous-multioutput':
362-
raise ValueError("Multioutput not supported in median_absolute_error")
363-
return np.median(np.abs(y_pred - y_true))
383+
y_type, y_true, y_pred, multioutput = _check_reg_targets(
384+
y_true, y_pred, multioutput)
385+
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
386+
if isinstance(multioutput, str):
387+
if multioutput == 'raw_values':
388+
return output_errors
389+
elif multioutput == 'uniform_average':
390+
# pass None as weights to np.average: uniform mean
391+
multioutput = None
392+
393+
return np.average(output_errors, weights=multioutput)
364394

365395

366396
def explained_variance_score(y_true, y_pred,

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
426426

427427
# Regression metrics with "multioutput-continuous" format support
428428
MULTIOUTPUT_METRICS = {
429-
"mean_absolute_error", "mean_squared_error", "r2_score",
430-
"explained_variance_score"
429+
"mean_absolute_error", "median_absolute_error", "mean_squared_error",
430+
"r2_score", "explained_variance_score"
431431
}
432432

433433
# Symmetric with respect to their input arguments y_true and y_pred

sklearn/metrics/tests/test_regression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def test_multioutput_regression():
7474
error = mean_absolute_error(y_true, y_pred)
7575
assert_almost_equal(error, (1. + 2. / 3) / 4.)
7676

77+
error = median_absolute_error(y_true, y_pred)
78+
assert_almost_equal(error, (1. + 1.) / 4.)
79+
7780
error = r2_score(y_true, y_pred, multioutput='variance_weighted')
7881
assert_almost_equal(error, 1. - 5. / 2)
7982
error = r2_score(y_true, y_pred, multioutput='uniform_average')

0 commit comments

Comments
 (0)
0