8000 RMSLE (root mean squared log error) (#20326) · samronsin/scikit-learn@2106944 · GitHub 65E5
[go: up one dir, main page]

Skip to content

Commit 2106944

Browse files
helper-uttamogrisel
authored andcommitted
RMSLE (root mean squared log error) (scikit-learn#20326)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 2d56d83 commit 2106944

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

doc/whats_new/v1.0.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,10 @@ Changelog
437437
quantile regression. :pr:`19415` by :user:`Xavier Dupré <sdpython>`
438438
and :user:`Oliver Grisel <ogrisel>`.
439439

440+
- |Feature| :func:`metrics.mean_squared_log_error` now supports
441+
`squared=False`.
442+
:pr:`20326` by :user:`Uttam kumar <helper-uttam>`.
443+
440444
- |Efficiency| Improved speed of :func:`metrics.confusion_matrix` when labels
441445
are integral.
442446
:pr:`9843` by :user:`Jon Crall <Erotemic>`.

sklearn/metrics/_regression.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
# Konstantin Shmelkov <konstantin.shmelkov@polytechnique.edu>
2222
# Christian Lorentzen <lorentzen.ch@googlemail.com>
2323
# Ashutosh Hathidara <ashutoshhathidara98@gmail.com>
24+
# Uttam kumar <bajiraouttamsinha@gmail.com>
2425
# License: BSD 3 clause
2526

2627
import numpy as np
@@ -437,7 +438,7 @@ def mean_squared_error(
437438

438439

439440
def mean_squared_log_error(
440-
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
441+
y_true, y_pred, *, sample_weight=None, multioutput="uniform_average", squared=True
441442
):
442443
"""Mean squared logarithmic error regression loss.
443444
@@ -466,6 +467,9 @@ def mean_squared_log_error(
466467
467468
'uniform_average' :
468469
Errors of all outputs are averaged with uniform weight.
470+
squared : bool, default=True
471+
If True returns MSLE (mean squared log error) value.
472+
If False returns RMSLE (root mean squared log error) value.
469473
470474
Returns
471475
-------
@@ -480,6 +484,8 @@ def mean_squared_log_error(
480484
>>> y_pred = [2.5, 5, 4, 8]
481485
>>> mean_squared_log_error(y_true, y_pred)
482486
0.039...
487+
>>> mean_squared_log_error(y_true, y_pred, squared=False)
488+
0.199...
483489
>>> y_true = [[0.5, 1], [1, 2], [7, 6]]
484490
>>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
485491
>>> mean_squared_log_error(y_true, y_pred)
@@ -505,6 +511,7 @@ def mean_squared_log_error(
505511
np.log1p(y_pred),
506512
sample_weight=sample_weight,
507513
multioutput=multioutput,
514+
squared=squared,
508515
)
509516

510517

0 commit comments

Comments
 (0)
0