21
21
# Konstantin Shmelkov <konstantin.shmelkov@polytechnique.edu>
22
22
# Christian Lorentzen <lorentzen.ch@googlemail.com>
23
23
# Ashutosh Hathidara <ashutoshhathidara98@gmail.com>
24
+ # Uttam kumar <bajiraouttamsinha@gmail.com>
24
25
# License: BSD 3 clause
25
26
26
27
import numpy as np
@@ -437,7 +438,7 @@ def mean_squared_error(
437
438
438
439
439
440
def mean_squared_log_error (
440
- y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average"
441
+ y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average" , squared = True
441
442
):
442
443
"""Mean squared logarithmic error regression loss.
443
444
@@ -466,6 +467,9 @@ def mean_squared_log_error(
466
467
467
468
'uniform_average' :
468
469
Errors of all outputs are averaged with uniform weight.
470
+ squared : bool, default=True
471
+ If True returns MSLE (mean squared log error) value.
472
+ If False returns RMSLE (root mean squared log error) value.
469
473
470
474
Returns
471
475
-------
@@ -480,6 +484,8 @@ def mean_squared_log_error(
480
484
>>> y_pred = [2.5, 5, 4, 8]
481
485
>>> mean_squared_log_error(y_true, y_pred)
482
486
0.039...
487
+ >>> mean_squared_log_error(y_true, y_pred, squared=False)
488
+ 0.199...
483
489
>>> y_true = [[0.5, 1], [1, 2], [7, 6]]
484
490
>>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
485
491
>>> mean_squared_log_error(y_true, y_pred)
@@ -505,6 +511,7 @@ def mean_squared_log_error(
505
511
np .log1p (y_pred ),
506
512
sample_weight = sample_weight ,
507
513
multioutput = multioutput ,
514
+ squared = squared ,
508
515
)
509
516
510
517
0 commit comments