[MRG+2] median_absolute_error multioutput by agamemnonc · Pull Request #14732 · scikit-learn/scikit-learn · GitHub

[MRG+2] median_absolute_error multioutput #14732

Merged
7 changes: 5 additions & 2 deletions doc/whats_new/v0.22.rst
@@ -460,6 +460,10 @@ Changelog
   :func:`metrics.pairwise.manhattan_distances` in the case of sparse matrices.
   :pr:`15049` by `Paolo Toccaceli <ptocca>`.

+- |Enhancement| :func:`metrics.median_absolute_error` now supports
+  ``multioutput`` parameter.
+  :pr:`14732` by :user:`Agamemnon Krasoulis <agamemnonc>`.
+
 :mod:`sklearn.model_selection`
 ..............................

@@ -663,7 +667,7 @@ Changelog
- |Fix| :func:`utils.check_array` will now correctly detect numeric dtypes in
pandas dataframes, fixing a bug where ``float32`` was upcast to ``float64``
unnecessarily. :pr:`15094` by `Andreas Müller`_.

- |API| The following utils have been deprecated and are now private:
- ``choose_check_classifiers_labels``
- ``enforce_estimator_tags_y``
@@ -719,4 +723,3 @@ These changes mostly affect library developers.
:pr:`13392` by :user:`Rok Mihevc <rok>`.

- |Fix| Added ``check_transformer_data_not_an_array`` to checks where missing

50 changes: 40 additions & 10 deletions sklearn/metrics/regression.py
@@ -330,23 +330,38 @@ def mean_squared_log_error(y_true, y_pred,
                               sample_weight, multioutput)


-def median_absolute_error(y_true, y_pred):
+def median_absolute_error(y_true, y_pred, multioutput='uniform_average'):
     """Median absolute error regression loss

-    Read more in the :ref:`User Guide <median_absolute_error>`.
+    Median absolute error output is non-negative floating point. The best value
+    is 0.0. Read more in the :ref:`User Guide <median_absolute_error>`.

     Parameters
     ----------
-    y_true : array-like of shape (n_samples,)
+    y_true : array-like of shape = (n_samples) or (n_samples, n_outputs)
         Ground truth (correct) target values.

-    y_pred : array-like of shape (n_samples,)
+    y_pred : array-like of shape = (n_samples) or (n_samples, n_outputs)
         Estimated target values.

+    multioutput : {'raw_values', 'uniform_average'} or array-like of shape
+        (n_outputs,)
+        Defines aggregating of multiple output values. Array-like value defines
+        weights used to average errors.
+
+        'raw_values' :
+            Returns a full set of errors in case of multioutput input.
+
+        'uniform_average' :
+            Errors of all outputs are averaged with uniform weight.
+
     Returns
     -------
-    loss : float
-        A positive floating point value (the best value is 0.0).
+    loss : float or ndarray of floats
+        If multioutput is 'raw_values', then median absolute error is returned
+        for each output separately.
+        If multioutput is 'uniform_average' or an ndarray of weights, then the
+        weighted average of all output errors is returned.

     Examples
     --------
@@ -355,12 +370,27 @@ def median_absolute_error(y_true, y_pred):
     >>> y_pred = [2.5, 0.0, 2, 8]
     >>> median_absolute_error(y_true, y_pred)
     0.5
+    >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
+    >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
+    >>> median_absolute_error(y_true, y_pred)
+    0.75
+    >>> median_absolute_error(y_true, y_pred, multioutput='raw_values')
+    array([0.5, 1. ])
+    >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
+    0.85

     """
-    y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None)
-    if y_type == 'continuous-multioutput':
-        raise ValueError("Multioutput not supported in median_absolute_error")
-    return np.median(np.abs(y_pred - y_true))
+    y_type, y_true, y_pred, multioutput = _check_reg_targets(
+        y_true, y_pred, multioutput)
+    output_errors = np.median(np.abs(y_pred - y_true), axis=0)
+    if isinstance(multioutput, str):
+        if multioutput == 'raw_values':
+            return output_errors
+        elif multioutput == 'uniform_average':
+            # pass None as weights to np.average: uniform mean
+            multioutput = None
+
+    return np.average(output_errors, weights=multioutput)


 def explained_variance_score(y_true, y_pred,
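
For readers who want to try the aggregation in the hunk above outside scikit-learn, here is a minimal NumPy-only sketch. It is not the merged implementation: `median_absolute_error_sketch` is a hypothetical name, and the reshaping of 1-D inputs is an assumed stand-in for the validation that `_check_reg_targets` performs in the library. The core steps (per-output median with `axis=0`, then `np.average` with optional weights) follow the diff.

```python
import numpy as np


def median_absolute_error_sketch(y_true, y_pred, multioutput="uniform_average"):
    """Hypothetical stand-alone version of the logic added in this PR."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Assumed stand-in for _check_reg_targets: treat 1-D input as one output.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    # One median per output column, as in the diff (axis=0).
    output_errors = np.median(np.abs(y_pred - y_true), axis=0)

    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            return output_errors
        if multioutput != "uniform_average":
            raise ValueError("unsupported multioutput value: %r" % multioutput)
        weights = None  # np.average with weights=None is a plain mean
    else:
        weights = np.asarray(multioutput, dtype=float)

    return np.average(output_errors, weights=weights)


# Inputs taken from the docstring example in the diff.
y_true = [[0.5, 1], [-1, 1], [7, -6]]
y_pred = [[0, 2], [-1, 2], [8, -5]]

raw = median_absolute_error_sketch(y_true, y_pred, multioutput="raw_values")
uniform = median_absolute_error_sketch(y_true, y_pred)
weighted = median_absolute_error_sketch(y_true, y_pred, multioutput=[0.3, 0.7])
```

The per-output medians are 0.5 and 1.0, so the uniform average is 0.75 and the weighted average is 0.3 * 0.5 + 0.7 * 1.0 = 0.85, matching the doctest values in the docstring.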
4 changes: 2 additions & 2 deletions sklearn/metrics/tests/test_common.py
@@ -426,8 +426,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):

 # Regression metrics with "multioutput-continuous" format support
 MULTIOUTPUT_METRICS = {
-    "mean_absolute_error", "mean_squared_error", "r2_score",
-    "explained_variance_score"
+    "mean_absolute_error", "median_absolute_error", "mean_squared_error",
+    "r2_score", "explained_variance_score"
 }

 # Symmetric with respect to their input arguments y_true and y_pred
3 changes: 3 additions & 0 deletions sklearn/metrics/tests/test_regression.py
@@ -74,6 +74,9 @@ def test_multioutput_regression():
     error = mean_absolute_error(y_true, y_pred)
     assert_almost_equal(error, (1. + 2. / 3) / 4.)

+    error = median_absolute_error(y_true, y_pred)
+    assert_almost_equal(error, (1. + 1.) / 4.)
+
     error = r2_score(y_true, y_pred, multioutput='variance_weighted')
     assert_almost_equal(error, 1. - 5. / 2)
     error = r2_score(y_true, y_pred, multioutput='uniform_average')
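
The expected value `(1. + 1.) / 4.` in the new assertion can be checked by hand with a short standard-library sketch. The `y_true` and `y_pred` arrays below are an assumption: the test fixture is not visible in this hunk, so they are hypothetical values chosen to be consistent with both expected values shown here (the mean absolute error of `(1. + 2. / 3) / 4.` and the median absolute error of `(1. + 1.) / 4.`).

```python
from statistics import fmean, median

# Hypothetical fixture (not shown in the diff): 3 samples, 4 outputs.
y_true = [[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]]
y_pred = [[0, 0, 0, 1], [1, 0, 1, 1], [0, 0, 0, 1]]

# Absolute error per sample and output.
abs_err = [
    [abs(t - p) for t, p in zip(row_true, row_pred)]
    for row_true, row_pred in zip(y_true, y_pred)
]

# Transpose so that each tuple holds the errors of one output.
columns = list(zip(*abs_err))

# Per-output statistics, then a uniform average over the outputs.
per_output_median = [median(col) for col in columns]  # [1, 1, 0, 0]
per_output_mean = [fmean(col) for col in columns]

medae = fmean(per_output_median)  # (1 + 1 + 0 + 0) / 4
mae = fmean(per_output_mean)      # (1 + 2/3 + 0 + 0) / 4
```

With these inputs the four per-output medians are 1, 1, 0 and 0, so the uniform average is (1 + 1) / 4 = 0.5, the value the new assertion expects.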