API replace mean_squared_error(square=False) by root_mean_squared_err… · glemaitre/scikit-learn@c86219f · GitHub
[go: up one dir, main page]

Skip to content

Commit c86219f

Browse files
101AlexMartinAlexMGTNOjeremiedbbglemaitre
committed
API replace mean_squared_error(square=False) by root_mean_squared_error (scikit-learn#26734)
Co-authored-by: Alejandro Martin <alejandro.martingil@tno.nl> Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2a06ca7 commit c86219f

File tree

8 files changed

+298
-26
lines changed

8 files changed

+298
-26
lines changed

doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,8 @@ details.
995995
metrics.median_absolute_error
996996
metrics.mean_absolute_percentage_error
997997
metrics.r2_score
998+
metrics.root_mean_squared_log_error
999+
metrics.root_mean_squared_error
9981000
metrics.mean_poisson_deviance
9991001
metrics.mean_gamma_deviance
10001002
metrics.mean_tweedie_deviance

doc/modules/model_evaluation.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ Scoring Function
9494
'max_error' :func:`metrics.max_error`
9595
'neg_mean_absolute_error' :func:`metrics.mean_absolute_error`
9696
'neg_mean_squared_error' :func:`metrics.mean_squared_error`
97-
'neg_root_mean_squared_error' :func:`metrics.mean_squared_error`
97+
'neg_root_mean_squared_error' :func:`metrics.root_mean_squared_error`
9898
'neg_mean_squared_log_error' :func:`metrics.mean_squared_log_error`
99+
'neg_root_mean_squared_log_error' :func:`metrics.root_mean_squared_log_error`
99100
'neg_median_absolute_error' :func:`metrics.median_absolute_error`
100101
'r2' :func:`metrics.r2_score`
101102
'neg_mean_poisson_deviance' :func:`metrics.mean_poisson_deviance`
@@ -2310,6 +2311,10 @@ function::
23102311
for an example of mean squared error usage to
23112312
evaluate gradient boosting regression.
23122313

2314+
Taking the square root of the MSE, called the root mean squared error (RMSE), is another
2315+
common metric that provides a measure in the same units as the target variable. RMSE is
2316+
available through the :func:`root_mean_squared_error` function.
2317+
23132318
.. _mean_squared_log_error:
23142319

23152320
Mean squared logarithmic error
@@ -2347,6 +2352,9 @@ function::
23472352
>>> mean_squared_log_error(y_true, y_pred)
23482353
0.044...
23492354

2355+
The root mean squared logarithmic error (RMSLE) is available through the
2356+
:func:`root_mean_squared_log_error` function.
2357+
23502358
.. _mean_absolute_percentage_error:
23512359

23522360
Mean absolute percentage error

sklearn/metrics/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
mean_tweedie_deviance,
6363
median_absolute_error,
6464
r2_score,
65+
root_mean_squared_error,
66+
root_mean_squared_log_error,
6567
)
6668
from ._scorer import check_scoring, get_scorer, get_scorer_names, make_scorer
6769
from .cluster import (
@@ -166,6 +168,8 @@
166168
"RocCurveDisplay",
167169
"roc_auc_score",
168170
"roc_curve",
171+
"root_mean_squared_log_error",
172+
"root_mean_squared_error",
169173
"get_scorer_names",
170174
"silhouette_samples",
171175
"silhouette_score",

sklearn/metrics/_regression.py

Lines changed: 207 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# Uttam kumar <bajiraouttamsinha@gmail.com>
2525
# Sylvain Marie <sylvain.marie@se.com>
2626
# Ohad Michel <ohadmich@gmail.com>
27+
# Alejandro Martin Gil <almagil98@gmail.com>
2728
# License: BSD 3 clause
2829

2930
import warnings
@@ -33,7 +34,7 @@
3334
from scipy.special import xlogy
3435

3536
from ..exceptions import UndefinedMetricWarning
36-
from ..utils._param_validation import Interval, StrOptions, validate_params
37+
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
3738
from ..utils.stats import _weighted_percentile
3839
from ..utils.validation import (
3940
_check_sample_weight,
@@ -52,6 +53,8 @@
5253
"mean_absolute_percentage_error",
5354
"mean_pinball_loss",
5455
"r2_score",
56+
"root_mean_squared_log_error",
57+
"root_mean_squared_error",
5558
"explained_variance_score",
5659
"mean_tweedie_deviance",
5760
"mean_poisson_deviance",
@@ -407,12 +410,17 @@ def mean_absolute_percentage_error(
407410
"y_pred": ["array-like"],
408411
"sample_weight": ["array-like", None],
409412
"multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],
410-
"squared": ["boolean"],
413+
"squared": [Hidden(StrOptions({"deprecated"})), "boolean"],
411414
},
412415
prefer_skip_nested_validation=True,
413416
)
414417
def mean_squared_error(
415-
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average", squared=True
418+
y_true,
419+
y_pred,
420+
*,
421+
sample_weight=None,
422+
multioutput="uniform_average",
423+
squared="deprecated",
416424
):
417425
"""Mean squared error regression loss.
418426
@@ -443,6 +451,11 @@ def mean_squared_error(
443451
squared : bool, default=True
444452
If True returns MSE value, if False returns RMSE value.
445453
454+
.. deprecated:: 1.4
455+
`squared` is deprecated in 1.4 and will be removed in 1.6.
456+
Use :func:`~sklearn.metrics.root_mean_squared_error`
457+
instead to calculate the root mean squared error.
458+
446459
Returns
447460
-------
448461
loss : float or ndarray of floats
@@ -456,29 +469,110 @@ def mean_squared_error(
456469
>>> y_pred = [2.5, 0.0, 2, 8]
457470
>>> mean_squared_error(y_true, y_pred)
458471
0.375
459-
>>> y_true = [3, -0.5, 2, 7]
460-
>>> y_pred = [2.5, 0.0, 2, 8]
461-
>>> mean_squared_error(y_true, y_pred, squared=False)
462-
0.612...
463472
>>> y_true = [[0.5, 1],[-1, 1],[7, -6]]
464473
>>> y_pred = [[0, 2],[-1, 2],[8, -5]]
465474
>>> mean_squared_error(y_true, y_pred)
466475
0.708...
467-
>>> mean_squared_error(y_true, y_pred, squared=False)
468-
0.822...
469476
>>> mean_squared_error(y_true, y_pred, multioutput='raw_values')
470477
array([0.41666667, 1. ])
471478
>>> mean_squared_error(y_true, y_pred, multioutput=[0.3, 0.7])
472479
0.825...
473480
"""
481+
# TODO(1.6): remove
482+
if squared != "deprecated":
483+
warnings.warn(
484+
(
485+
"'squared' is deprecated in version 1.4 and "
486+
"will be removed in 1.6. To calculate the "
487+
"root mean squared error, use the function "
488+
"'root_mean_squared_error'."
489+
),
490+
FutureWarning,
491+
)
492+
if not squared:
493+
return root_mean_squared_error(
494+
y_true, y_pred, sample_weight=sample_weight, multioutput=multioutput
495+
)
496+
474497
y_type, y_true, y_pred, multioutput = _check_reg_targets(
475498
y_true, y_pred, multioutput
476499
)
477500
check_consistent_length(y_true, y_pred, sample_weight)
478501
output_errors = np.average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)
479502

480-
if not squared:
481-
output_errors = np.sqrt(output_errors)
503+
if isinstance(multioutput, str):
504+
if multioutput == "raw_values":
505+
return output_errors
506+
elif multioutput == "uniform_average":
507+
# pass None as weights to np.average: uniform mean
508+
multioutput = None
509+
510+
return np.average(output_errors, weights=multioutput)
511+
512+
513+
@validate_params(
514+
{
515+
"y_true": ["array-like"],
516+
"y_pred": ["array-like"],
517+
"sample_weight": ["array-like", None],
518+
"multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],
519+
},
520+
prefer_skip_nested_validation=True,
521+
)
522+
def root_mean_squared_error(
523+
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
524+
):
525+
"""Root mean squared error regression loss.
526+
527+
Read more in the :ref:`User Guide <mean_squared_error>`.
528+
529+
.. versionadded:: 1.4
530+
531+
Parameters
532+
----------
533+
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
534+
Ground truth (correct) target values.
535+
536+
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
537+
Estimated target values.
538+
539+
sample_weight : array-like of shape (n_samples,), default=None
540+
Sample weights.
541+
542+
multioutput : {'raw_values', 'uniform_average'} or array-like of shape \
543+
(n_outputs,), default='uniform_average'
544+
Defines aggregating of multiple output values.
545+
Array-like value defines weights used to average errors.
546+
547+
'raw_values' :
548+
Returns a full set of errors in case of multioutput input.
549+
550+
'uniform_average' :
551+
Errors of all outputs are averaged with uniform weight.
552+
553+
Returns
554+
-------
555+
loss : float or ndarray of floats
556+
A non-negative floating point value (the best value is 0.0), or an
557+
array of floating point values, one for each individual target.
558+
559+
Examples
560+
--------
561+
>>> from sklearn.metrics import root_mean_squared_error
562+
>>> y_true = [3, -0.5, 2, 7]
563+
>>> y_pred = [2.5, 0.0, 2, 8]
564+
>>> root_mean_squared_error(y_true, y_pred)
565+
0.612...
566+
>>> y_true = [[0.5, 1],[-1, 1],[7, -6]]
567+
>>> y_pred = [[0, 2],[-1, 2],[8, -5]]
568+
>>> root_mean_squared_error(y_true, y_pred)
569+
0.822...
570+
"""
571+
output_errors = np.sqrt(
572+
mean_squared_error(
573+
y_true, y_pred, sample_weight=sample_weight, multioutput="raw_values"
574+
)
575+
)
482576

483577
if isinstance(multioutput, str):
484578
if multioutput == "raw_values":
@@ -496,12 +590,17 @@ def mean_squared_error(
496590
"y_pred": ["array-like"],
497591
"sample_weight": ["array-like", None],
498592
"multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],
499-
"squared": ["boolean"],
593+
"squared": [Hidden(StrOptions({"deprecated"})), "boolean"],
500594
},
501595
prefer_skip_nested_validation=True,
502596
)
503597
def mean_squared_log_error(
504-
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average", squared=True
598+
y_true,
599+
y_pred,
600+
*,
601+
sample_weight=None,
602+
multioutput="uniform_average",
603+
squared="deprecated",
505604
):
506605
"""Mean squared logarithmic error regression loss.
507606
@@ -530,10 +629,16 @@ def mean_squared_log_error(
530629
531630
'uniform_average' :
532631
Errors of all outputs are averaged with uniform weight.
632+
533633
squared : bool, default=True
534634
If True returns MSLE (mean squared log error) value.
535635
If False returns RMSLE (root mean squared log error) value.
536636
637+
.. deprecated:: 1.4
638+
`squared` is deprecated in 1.4 and will be removed in 1.6.
639+
Use :func:`~sklearn.metrics.root_mean_squared_log_error`
640+
instead to calculate the root mean squared logarithmic error.
641+
537642
Returns
538643
-------
539644
loss : float or ndarray of floats
@@ -547,8 +652,6 @@ def mean_squared_log_error(
547652
>>> y_pred = [2.5, 5, 4, 8]
548653
>>> mean_squared_log_error(y_true, y_pred)
549654
0.039...
550-
>>> mean_squared_log_error(y_true, y_pred, squared=False)
551-
0.199...
552655
>>> y_true = [[0.5, 1], [1, 2], [7, 6]]
553656
>>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
554657
>>> mean_squared_log_error(y_true, y_pred)
@@ -558,6 +661,22 @@ def mean_squared_log_error(
558661
>>> mean_squared_log_error(y_true, y_pred, multioutput=[0.3, 0.7])
559662
0.060...
560663
"""
664+
# TODO(1.6): remove
665+
if squared != "deprecated":
666+
warnings.warn(
667+
(
668+
"'squared' is deprecated in version 1.4 and "
669+
"will be removed in 1.6. To calculate the "
670+
"root mean squared logarithmic error, use the function "
671+
"'root_mean_squared_log_error'."
672+
),
673+
FutureWarning,
674+
)
675+
if not squared:
676+
return root_mean_squared_log_error(
677+
y_true, y_pred, sample_weight=sample_weight, multioutput=multioutput
678+
)
679+
561680
y_type, y_true, y_pred, multioutput = _check_reg_targets(
562681
y_true, y_pred, multioutput
563682
)
@@ -574,7 +693,79 @@ def mean_squared_log_error(
574693
np.log1p(y_pred),
575694
sample_weight=sample_weight,
576695
multioutput=multioutput,
577-
squared=squared,
696+
)
697+
698+
699+
@validate_params(
700+
{
701+
"y_true": ["array-like"],
702+
"y_pred": ["array-like"],
703+
"sample_weight": ["array-like", None],
704+
"multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],
705+
},
706+
prefer_skip_nested_validation=True,
707+
)
708+
def root_mean_squared_log_error(
709+
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
710+
):
711+
"""Root mean squared logarithmic error regression loss.
712+
713+
Read more in the :ref:`User Guide <mean_squared_log_error>`.
714+
715+
.. versionadded:: 1.4
716+
717+
Parameters
718+
----------
719+
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
720+
Ground truth (correct) target values.
721+
722+
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
723+
Estimated target values.
724+
725+
sample_weight : array-like of shape (n_samples,), default=None
726+
Sample weights.
727+
728+
multioutput : {'raw_values', 'uniform_average'} or array-like of shape \
729+
(n_outputs,), default='uniform_average'
730+
731+
Defines aggregating of multiple output values.
732+
Array-like value defines weights used to average errors.
733+
734+
'raw_values' :
735+
Returns a full set of errors when the input is of multioutput
736+
format.
737+
738+
'uniform_average' :
739+
Errors of all outputs are averaged with uniform weight.
740+
741+
Returns
742+
-------
743+
loss : float or ndarray of floats
744+
A non-negative floating point value (the best value is 0.0), or an
745+
array of floating point values, one for each individual target.
746+
747+
Examples
748+
--------
749+
>>> from sklearn.metrics import root_mean_squared_log_error
750+
>>> y_true = [3, 5, 2.5, 7]
751+
>>> y_pred = [2.5, 5, 4, 8]
752+
>>> root_mean_squared_log_error(y_true, y_pred)
753+
0.199...
754+
"""
755+
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
756+
check_consistent_length(y_true, y_pred, sample_weight)
757+
758+
if (y_true < 0).any() or (y_pred < 0).any():
759+
raise ValueError(
760+
"Root Mean Squared Logarithmic Error cannot be used when "
761+
"targets contain negative values."
762+
)
763+
764+
return root_mean_squared_error(
765+
np.log1p(y_true),
766+
np.log1p(y_pred),
767+
sample_weight=sample_weight,
768+
multioutput=multioutput,
578769
)
579770

580771

0 commit comments

Comments
 (0)
0