[MRG+1] Add new regression metric - Mean Squared Log Error (#7655) · paulha/scikit-learn@99d74b4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 99d74b4

Browse files
Karan Desai authored and paulha committed
[MRG+1] Add new regression metric - Mean Squared Log Error (scikit-learn#7655)
* ENH Implement mean squared log error in sklearn.metrics.regression * TST Add tests for mean squared log error. * DOC Write user guide and docstring about mean squared log error. * ENH Add neg_mean_squared_log_error in metrics.scorer
1 parent ad83b63 commit 99d74b4

File tree

7 files changed

+148
-8
lines changed

7 files changed

+148
-8
lines changed

doc/modules/classes.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,7 @@ details.
844844
metrics.explained_variance_score
845845
metrics.mean_absolute_error
846846
metrics.mean_squared_error
847+
metrics.mean_squared_log_error
847848
metrics.median_absolute_error
848849
metrics.r2_score
849850

@@ -1418,4 +1419,4 @@ To be removed in 0.20
14181419
cross_validation.cross_val_score
14191420
cross_validation.check_cv
14201421
cross_validation.permutation_test_score
1421-
cross_validation.train_test_split
1422+
cross_validation.train_test_split

doc/modules/model_evaluation.rst

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Scoring Function Co
7777
**Regression**
7878
'neg_mean_absolute_error' :func:`metrics.mean_absolute_error`
7979
'neg_mean_squared_error' :func:`metrics.mean_squared_error`
80+
'neg_mean_squared_log_error' :func:`metrics.mean_squared_log_error`
8081
'neg_median_absolute_error' :func:`metrics.median_absolute_error`
8182
'r2' :func:`metrics.r2_score`
8283
=========================== ========================================= ==================================
@@ -93,7 +94,7 @@ Usage examples:
9394
>>> model = svm.SVC()
9495
>>> cross_val_score(model, X, y, scoring='wrong_choice')
9596
Traceback (most recent call last):
96-
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc']
97+
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc']
9798

9899
.. note::
99100

@@ -1360,7 +1361,7 @@ Mean squared error
13601361

13611362
The :func:`mean_squared_error` function computes `mean square
13621363
error <https://en.wikipedia.org/wiki/Mean_squared_error>`_, a risk
1363-
metric corresponding to the expected value of the squared (quadratic) error loss or
1364+
metric corresponding to the expected value of the squared (quadratic) error or
13641365
loss.
13651366

13661367
If :math:`\hat{y}_i` is the predicted value of the :math:`i`-th sample,
@@ -1390,6 +1391,43 @@ function::
13901391
for an example of mean squared error usage to
13911392
evaluate gradient boosting regression.
13921393

1394+
.. _mean_squared_log_error:
1395+
1396+
Mean squared logarithmic error
1397+
------------------------------
1398+
1399+
The :func:`mean_squared_log_error` function computes a risk metric
1400+
corresponding to the expected value of the squared logarithmic (quadratic)
1401+
error or loss.
1402+
1403+
If :math:`\hat{y}_i` is the predicted value of the :math:`i`-th sample,
1404+
and :math:`y_i` is the corresponding true value, then the mean squared
1405+
logarithmic error (MSLE) estimated over :math:`n_{\text{samples}}` is
1406+
defined as
1407+
1408+
.. math::
1409+
1410+
\text{MSLE}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples} - 1} (\log_e (1 + y_i) - \log_e (1 + \hat{y}_i) )^2.
1411+
1412+
Where :math:`\log_e (x)` means the natural logarithm of :math:`x`. This metric
1413+
is best to use when targets have exponential growth, such as population
1414+
counts, average sales of a commodity over a span of years etc. Note that this
1415+
metric penalizes an under-predicted estimate more than an over-predicted
1416+
estimate.
1417+
1418+
Here is a small example of usage of the :func:`mean_squared_log_error`
1419+
function::
1420+
1421+
>>> from sklearn.metrics import mean_squared_log_error
1422+
>>> y_true = [3, 5, 2.5, 7]
1423+
>>> y_pred = [2.5, 5, 4, 8]
1424+
>>> mean_squared_log_error(y_true, y_pred) # doctest: +ELLIPSIS
1425+
0.039...
1426+
>>> y_true = [[0.5, 1], [1, 2], [7, 6]]
1427+
>>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
1428+
>>> mean_squared_log_error(y_true, y_pred) # doctest: +ELLIPSIS
1429+
0.044...
1430+
13931431
.. _median_absolute_error:
13941432

13951433
Median absolute error

sklearn/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .regression import explained_variance_score
5555
from .regression import mean_absolute_error
5656
from .regression import mean_squared_error
57+
from .regression import mean_squared_log_error
5758
from .regression import median_absolute_error
5859
from .regression import r2_score
5960

@@ -90,6 +91,7 @@
9091
'matthews_corrcoef',
9192
'mean_absolute_error',
9293
'mean_squared_error',
94+
'mean_squared_log_error',
9395
'median_absolute_error',
9496
'mutual_info_score',
9597
'normalized_mutual_info_score',

sklearn/metrics/regression.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# Jochen Wersdorfer <jochen@wersdoerfer.de>
1515
# Lars Buitinck
1616
# Joel Nothman <joel.nothman@gmail.com>
17+
# Karan Desai <karandesai281196@gmail.com>
1718
# Noel Dawe <noel@dawe.me>
1819
# Manoj Kumar <manojkumarsivaraj334@gmail.com>
1920
# Michael Eickenberg <michael.eickenberg@gmail.com>
@@ -33,6 +34,7 @@
3334
__ALL__ = [
3435
"mean_absolute_error",
3536
"mean_squared_error",
37+
"mean_squared_log_error",
3638
"median_absolute_error",
3739
"r2_score",
3840
"explained_variance_score"
@@ -241,6 +243,73 @@ def mean_squared_error(y_true, y_pred,
241243
return np.average(output_errors, weights=multioutput)
242244

243245

246+
def mean_squared_log_error(y_true, y_pred,
                           sample_weight=None,
                           multioutput='uniform_average'):
    """Mean squared logarithmic error regression loss

    Computes the mean squared error between ``log(1 + y_true)`` and
    ``log(1 + y_pred)``.

    Read more in the :ref:`User Guide <mean_squared_log_error>`.

    Parameters
    ----------
    y_true : array-like of shape = (n_samples) or (n_samples, n_outputs)
        Ground truth (correct) target values.

    y_pred : array-like of shape = (n_samples) or (n_samples, n_outputs)
        Estimated target values.

    sample_weight : array-like of shape = (n_samples), optional
        Sample weights.

    multioutput : string in ['raw_values', 'uniform_average'] \
            or array-like of shape = (n_outputs)

        Defines aggregating of multiple output values.
        Array-like value defines weights used to average errors.

        'raw_values' :
            Returns a full set of errors when the input is of multioutput
            format.

        'uniform_average' :
            Errors of all outputs are averaged with uniform weight.

    Returns
    -------
    loss : float or ndarray of floats
        A non-negative floating point value (the best value is 0.0), or an
        array of floating point values, one for each individual target.

    Raises
    ------
    ValueError
        If ``y_true`` or ``y_pred`` contains a negative value.

    Examples
    --------
    >>> from sklearn.metrics import mean_squared_log_error
    >>> y_true = [3, 5, 2.5, 7]
    >>> y_pred = [2.5, 5, 4, 8]
    >>> mean_squared_log_error(y_true, y_pred)  # doctest: +ELLIPSIS
    0.039...
    >>> y_true = [[0.5, 1], [1, 2], [7, 6]]
    >>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
    >>> mean_squared_log_error(y_true, y_pred)  # doctest: +ELLIPSIS
    0.044...
    >>> mean_squared_log_error(y_true, y_pred, multioutput='raw_values')
    ... # doctest: +ELLIPSIS
    array([ 0.004...,  0.083...])
    >>> mean_squared_log_error(y_true, y_pred, multioutput=[0.3, 0.7])
    ... # doctest: +ELLIPSIS
    0.060...

    """
    y_type, y_true, y_pred, multioutput = _check_reg_targets(
        y_true, y_pred, multioutput)

    # Both arrays must be entirely non-negative. The check has to fail when
    # *either* one contains a negative entry; requiring both to be invalid
    # would let a single bad array reach the logarithm and yield NaN/-inf.
    if not ((y_true >= 0).all() and (y_pred >= 0).all()):
        raise ValueError("Mean Squared Logarithmic Error cannot be used when "
                         "targets contain negative values.")

    # log1p(x) == log(1 + x), but stays accurate for x close to zero.
    return mean_squared_error(np.log1p(y_true), np.log1p(y_pred),
                              sample_weight, multioutput)
312+
244313
def median_absolute_error(y_true, y_pred):
245314
"""Median absolute error regression loss
246315

sklearn/metrics/scorer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
import numpy as np
2525

2626
from . import (r2_score, median_absolute_error, mean_absolute_error,
27-
mean_squared_error, accuracy_score, f1_score,
28-
roc_auc_score, average_precision_score,
27+
mean_squared_error, mean_squared_log_error, accuracy_score,
28+
f1_score, roc_auc_score, average_precision_score,
2929
precision_score, recall_score, log_loss)
3030
from .cluster import adjusted_rand_score
3131
from ..utils.multiclass import type_of_target
@@ -349,6 +349,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
349349
mean_squared_error_scorer = make_scorer(mean_squared_error,
350350
greater_is_better=False)
351351
mean_squared_error_scorer._deprecation_msg = deprecation_msg
352+
neg_mean_squared_log_error_scorer = make_scorer(mean_squared_log_error,
353+
greater_is_better=False)
352354
neg_mean_absolute_error_scorer = make_scorer(mean_absolute_error,
353355
greater_is_better=False)
354356
deprecation_msg = ('Scoring method mean_absolute_error was renamed to '
@@ -396,6 +398,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
396398
neg_median_absolute_error=neg_median_absolute_error_scorer,
397399
neg_mean_absolute_error=neg_mean_absolute_error_scorer,
398400
neg_mean_squared_error=neg_mean_squared_error_scorer,
401+
neg_mean_squared_log_error=neg_mean_squared_log_error_scorer,
399402
median_absolute_error=median_absolute_error_scorer,
400403
mean_absolute_error=mean_absolute_error_scorer,
401404
mean_squared_error=mean_squared_error_scorer,

sklearn/metrics/tests/test_regression.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from itertools import product
55

6-
from sklearn.utils.testing import assert_raises
6+
from sklearn.utils.testing import assert_raises, assert_raises_regex
77
from sklearn.utils.testing import assert_equal
88
from sklearn.utils.testing import assert_almost_equal
99
from sklearn.utils.testing import assert_array_equal
@@ -12,6 +12,7 @@
1212
from sklearn.metrics import explained_variance_score
1313
from sklearn.metrics import mean_absolute_error
1414
from sklearn.metrics import mean_squared_error
15+
from sklearn.metrics import mean_squared_log_error
1516
from sklearn.metrics import median_absolute_error
1617
from sklearn.metrics import r2_score
1718

@@ -23,6 +24,9 @@ def test_regression_metrics(n_samples=50):
2324
y_pred = y_true + 1
2425

2526
assert_almost_equal(mean_squared_error(y_true, y_pred), 1.)
27+
assert_almost_equal(mean_squared_log_error(y_true, y_pred),
28+
mean_squared_error(np.log(1 + y_true),
29+
np.log(1 + y_pred)))
2630
assert_almost_equal(mean_absolute_error(y_true, y_pred), 1.)
2731
assert_almost_equal(median_absolute_error(y_true, y_pred), 1.)
2832
assert_almost_equal(r2_score(y_true, y_pred), 0.995, 2)
@@ -36,6 +40,9 @@ def test_multioutput_regression():
3640
error = mean_squared_error(y_true, y_pred)
3741
assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.)
3842

43+
error = mean_squared_log_error(y_true, y_pred)
44+
assert_almost_equal(error, 0.200, decimal=2)
45+
3946
# mean_absolute_error and mean_squared_error are equal because
4047
# it is a binary problem.
4148
error = mean_absolute_error(y_true, y_pred)
@@ -49,10 +56,14 @@ def test_multioutput_regression():
4956

5057
def test_regression_metrics_at_limits():
5158
assert_almost_equal(mean_squared_error([0.], [0.]), 0.00, 2)
59+
assert_almost_equal(mean_squared_log_error([0.], [0.]), 0.00, 2)
5260
assert_almost_equal(mean_absolute_error([0.], [0.]), 0.00, 2)
5361
assert_almost_equal(median_absolute_error([0.], [0.]), 0.00, 2)
5462
assert_almost_equal(explained_variance_score([0.], [0.]), 1.00, 2)
5563
assert_almost_equal(r2_score([0., 1], [0., 1]), 1.00, 2)
64+
assert_raises_regex(ValueError, "Mean Squared Logarithmic Error cannot be "
65+
"used when targets contain negative values.",
66+
mean_squared_log_error, [-1.], [-1.])
5667

5768

5869
def test__check_reg_targets():
@@ -127,6 +138,14 @@ def test_regression_multioutput_array():
127138
assert_array_almost_equal(evs, [1., -3.], decimal=2)
128139
assert_equal(np.mean(evs), explained_variance_score(y_true, y_pred))
129140

141+
# Handling msle separately as it does not accept negative inputs.
142+
y_true = np.array([[0.5, 1], [1, 2], [7, 6]])
143+
y_pred = np.array([[0.5, 2], [1, 2.5], [8, 8]])
144+
msle = mean_squared_log_error(y_true, y_pred, multioutput='raw_values')
145+
msle2 = mean_squared_error(np.log(1 + y_true), np.log(1 + y_pred),
146+
multioutput='raw_values')
147+
assert_array_almost_equal(msle, msle2, decimal=2)
148+
130149

131150
def test_regression_custom_weights():
132151
y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
@@ -141,3 +160,11 @@ def test_regression_custom_weights():
141160
assert_almost_equal(maew, 0.475, decimal=3)
142161
assert_almost_equal(rw, 0.94, decimal=2)
143162
assert_almost_equal(evsw, 0.94, decimal=2)
163+
164+
# Handling msle separately as it does not accept negative inputs.
165+
y_true = np.array([[0.5, 1], [1, 2], [7, 6]])
166+
y_pred = np.array([[0.5, 2], [1, 2.5], [8, 8]])
167+
msle = mean_squared_log_error(y_true, y_pred, multioutput=[0.3, 0.7])
168+
msle2 = mean_squared_error(np.log(1 + y_true), np.log(1 + y_pred),
169+
multioutput=[0.3, 0.7])
170+
assert_almost_equal(msle, msle2, decimal=2)

sklearn/metrics/tests/test_score_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939

4040

4141
REGRESSION_SCORERS = ['r2', 'neg_mean_absolute_error',
42-
'neg_mean_squared_error', 'neg_median_absolute_error',
43-
'mean_absolute_error',
42+
'neg_mean_squared_error', 'neg_mean_squared_log_error',
43+
'neg_median_absolute_error', 'mean_absolute_error',
4444
'mean_squared_error', 'median_absolute_error']
4545

4646
CLF_SCORERS = ['accuracy', 'f1', 'f1_weighted', 'f1_macro', 'f1_micro',

0 commit comments

Comments
 (0)
0