24
24
# Uttam kumar <bajiraouttamsinha@gmail.com>
25
25
# Sylvain Marie <sylvain.marie@se.com>
26
26
# Ohad Michel <ohadmich@gmail.com>
27
+ # Alejandro Martin Gil <almagil98@gmail.com>
27
28
# License: BSD 3 clause
28
29
29
30
import warnings
33
34
from scipy .special import xlogy
34
35
35
36
from ..exceptions import UndefinedMetricWarning
36
- from ..utils ._param_validation import Interval , StrOptions , validate_params
37
+ from ..utils ._param_validation import Hidden , Interval , StrOptions , validate_params
37
38
from ..utils .stats import _weighted_percentile
38
39
from ..utils .validation import (
39
40
_check_sample_weight ,
52
53
"mean_absolute_percentage_error" ,
53
54
"mean_pinball_loss" ,
54
55
"r2_score" ,
56
+ "root_mean_squared_log_error" ,
57
+ "root_mean_squared_error" ,
55
58
"explained_variance_score" ,
56
59
"mean_tweedie_deviance" ,
57
60
"mean_poisson_deviance" ,
@@ -407,12 +410,17 @@ def mean_absolute_percentage_error(
407
410
"y_pred" : ["array-like" ],
408
411
"sample_weight" : ["array-like" , None ],
409
412
"multioutput" : [StrOptions ({"raw_values" , "uniform_average" }), "array-like" ],
410
- "squared" : ["boolean" ],
413
+ "squared" : [Hidden ( StrOptions ({ "deprecated" })), "boolean" ],
411
414
},
412
415
prefer_skip_nested_validation = True ,
413
416
)
414
417
def mean_squared_error (
415
- y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average" , squared = True
418
+ y_true ,
419
+ y_pred ,
420
+ * ,
421
+ sample_weight = None ,
422
+ multioutput = "uniform_average" ,
423
+ squared = "deprecated" ,
416
424
):
417
425
"""Mean squared error regression loss.
418
426
@@ -443,6 +451,11 @@ def mean_squared_error(
443
451
squared : bool, default=True
444
452
If True returns MSE value, if False returns RMSE value.
445
453
454
+ .. deprecated:: 1.4
455
+ `squared` is deprecated in 1.4 and will be removed in 1.6.
456
+ Use :func:`~sklearn.metrics.root_mean_squared_error`
457
+ instead to calculate the root mean squared error.
458
+
446
459
Returns
447
460
-------
448
461
loss : float or ndarray of floats
@@ -456,29 +469,110 @@ def mean_squared_error(
456
469
>>> y_pred = [2.5, 0.0, 2, 8]
457
470
>>> mean_squared_error(y_true, y_pred)
458
471
0.375
459
- >>> y_true = [3, -0.5, 2, 7]
460
- >>> y_pred = [2.5, 0.0, 2, 8]
461
- >>> mean_squared_error(y_true, y_pred, squared=False)
462
- 0.612...
463
472
>>> y_true = [[0.5, 1],[-1, 1],[7, -6]]
464
473
>>> y_pred = [[0, 2],[-1, 2],[8, -5]]
465
474
>>> mean_squared_error(y_true, y_pred)
466
475
0.708...
467
- >>> mean_squared_error(y_true, y_pred, squared=False)
468
- 0.822...
469
476
>>> mean_squared_error(y_true, y_pred, multioutput='raw_values')
470
477
array([0.41666667, 1. ])
471
478
>>> mean_squared_error(y_true, y_pred, multioutput=[0.3, 0.7])
472
479
0.825...
473
480
"""
481
+ # TODO(1.6): remove
482
+ if squared != "deprecated" :
483
+ warnings .warn (
484
+ (
485
+ "'squared' is deprecated in version 1.4 and "
486
+ "will be removed in 1.6. To calculate the "
487
+ "root mean squared error, use the function"
488
+ "'root_mean_squared_error'."
489
+ ),
490
+ FutureWarning ,
491
+ )
492
+ if not squared :
493
+ return root_mean_squared_error (
494
+ y_true , y_pred , sample_weight = sample_weight , multioutput = multioutput
495
+ )
496
+
474
497
y_type , y_true , y_pred , multioutput = _check_reg_targets (
475
498
y_true , y_pred , multioutput
476
499
)
477
500
check_consistent_length (y_true , y_pred , sample_weight )
478
501
output_errors = np .average ((y_true - y_pred ) ** 2 , axis = 0 , weights = sample_weight )
479
502
480
- if not squared :
481
- output_errors = np .sqrt (output_errors )
503
+ if isinstance (multioutput , str ):
504
+ if multioutput == "raw_values" :
505
+ return output_errors
506
+ elif multioutput == "uniform_average" :
507
+ # pass None as weights to np.average: uniform mean
508
+ multioutput = None
509
+
510
+ return np .average (output_errors , weights = multioutput )
511
+
512
+
513
+ @validate_params (
514
+ {
515
+ "y_true" : ["array-like" ],
516
+ "y_pred" : ["array-like" ],
517
+ "sample_weight" : ["array-like" , None ],
518
+ "multioutput" : [StrOptions ({"raw_values" , "uniform_average" }), "array-like" ],
519
+ },
520
+ prefer_skip_nested_validation = True ,
521
+ )
522
+ def root_mean_squared_error (
523
+ y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average"
524
+ ):
525
+ """Root mean squared error regression loss.
526
+
527
+ Read more in the :ref:`User Guide <mean_squared_error>`.
528
+
529
+ .. versionadded:: 1.4
530
+
531
+ Parameters
532
+ ----------
533
+ y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
534
+ Ground truth (correct) target values.
535
+
536
+ y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
537
+ Estimated target values.
538
+
539
+ sample_weight : array-like of shape (n_samples,), default=None
540
+ Sample weights.
541
+
542
+ multioutput : {'raw_values', 'uniform_average'} or array-like of shape \
543
+ (n_outputs,), default='uniform_average'
544
+ Defines aggregating of multiple output values.
545
+ Array-like value defines weights used to average errors.
546
+
547
+ 'raw_values' :
548
+ Returns a full set of errors in case of multioutput input.
549
+
550
+ 'uniform_average' :
551
+ Errors of all outputs are averaged with uniform weight.
552
+
553
+ Returns
554
+ -------
555
+ loss : float or ndarray of floats
556
+ A non-negative floating point value (the best value is 0.0), or an
557
+ array of floating point values, one for each individual target.
558
+
559
+ Examples
560
+ --------
561
+ >>> from sklearn.metrics import root_mean_squared_error
562
+ >>> y_true = [3, -0.5, 2, 7]
563
+ >>> y_pred = [2.5, 0.0, 2, 8]
564
+ >>> root_mean_squared_error(y_true, y_pred)
565
+ 0.612...
566
+ >>> y_true = [[0.5, 1],[-1, 1],[7, -6]]
567
+ >>> y_pred = [[0, 2],[-1, 2],[8, -5]]
568
+ >>> root_mean_squared_error(y_true, y_pred)
569
+ 0.822...
570
+ """
571
+ output_errors = np .sqrt (
572
+ mean_squared_error (
573
+ y_true , y_pred , sample_weight = sample_weight , multioutput = "raw_values"
574
+ )
575
+ )
482
576
483
577
if isinstance (multioutput , str ):
484
578
if multioutput == "raw_values" :
@@ -496,12 +590,17 @@ def mean_squared_error(
496
590
"y_pred" : ["array-like" ],
497
591
"sample_weight" : ["array-like" , None ],
498
592
"multioutput" : [StrOptions ({"raw_values" , "uniform_average" }), "array-like" ],
499
- "squared" : ["boolean" ],
593
+ "squared" : [Hidden ( StrOptions ({ "deprecated" })), "boolean" ],
500
594
},
501
595
prefer_skip_nested_validation = True ,
502
596
)
503
597
def mean_squared_log_error (
504
- y_true , y_pred , * , sample_weight = None , multioutput = "uniform_average" , squared = True
598
+ y_true ,
599
+ y_pred ,
600
+ * ,
601
+ sample_weight = None ,
602
+ multioutput = "uniform_average" ,
603
+ squared = "deprecated" ,
505
604
):
506
605
"""Mean squared logarithmic error regression loss.
507
606
@@ -530,10 +629,16 @@ def mean_squared_log_error(
530
629
531
630
'uniform_average' :
532
631
Errors of all outputs are averaged with uniform weight.
632
+
533
633
squared : bool, default=True
534
634
If True returns MSLE (mean squared log error) value.
535
635
If False returns RMSLE (root mean squared log error) value.
536
636
637
+ .. deprecated:: 1.4
638
+ `squared` is deprecated in 1.4 and will be removed in 1.6.
639
+ Use :func:`~sklearn.metrics.root_mean_squared_log_error`
640
+ instead to calculate the root mean squared logarithmic error.
641
+
537
642
Returns
538
643
-------
539
644
loss : float or ndarray of floats
@@ -547,8 +652,6 @@ def mean_squared_log_error(
547
652
>>> y_pred = [2.5, 5, 4, 8]
548
653
>>> mean_squared_log_error(y_true, y_pred)
549
654
0.039...
550
- >>> mean_squared_log_error(y_true, y_pred, squared=False)
551
- 0.199...
552
655
>>> y_true = [[0.5, 1], [1, 2], [7, 6]]
553
656
>>> y_pred = [[0.5, 2], [1, 2.5], [8, 8]]
554
657
>>> mean_squared_log_error(y_true, y_pred)
@@ -558,6 +661,22 @@ def mean_squared_log_error(
558
661
>>> mean_squared_log_error(y_true, y_pred, multioutput=[0.3, 0.7])
559
662
0.060...
560
663
"""
664
+ # TODO(1.6): remove
665
+ if squared != "deprecated" :
666
+ warnings .warn (
667
+ (
668
+ "'squared' is deprecated in version 1.4 and "
669
+ "will be removed in 1.6. To calculate the "
670
+ "root mean squared logarithmic error, use the function"
671
+ "'root_mean_squared_log_error'."
672
+ ),
673
+ FutureWarning ,
674
+ )
675
+ if not squared :
676
+ return root_mean_squared_log_error (
677
+ y_true , y_pred , sample_weight = sample_weight , multioutput = multioutput
678
+ )
679
+
561
680
y_type , y_true , y_pred , multioutput = _check_reg_targets (
562
681
y_true , y_pred , multioutput
563
682
)
@@ -574,7 +693,79 @@ def mean_squared_log_error(
574
693
np .log1p (y_pred ),
575
694
sample_weight = sample_weight ,
576
695
multioutput = multioutput ,
577
- squared = squared ,
696
+ )
697
+
698
+
699
@validate_params(
    {
        "y_true": ["array-like"],
        "y_pred": ["array-like"],
        "sample_weight": ["array-like", None],
        "multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],
    },
    prefer_skip_nested_validation=True,
)
def root_mean_squared_log_error(
    y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
):
    """Root mean squared logarithmic error regression loss.

    The loss is the root mean squared error computed on ``log1p`` of the
    targets and of the predictions.

    Read more in the :ref:`User Guide <mean_squared_log_error>`.

    .. versionadded:: 1.4

    Parameters
    ----------
    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.

    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    multioutput : {'raw_values', 'uniform_average'} or array-like of shape \
            (n_outputs,), default='uniform_average'
        Defines aggregating of multiple output values.
        Array-like value defines weights used to average errors.

        'raw_values' :
            Returns a full set of errors when the input is of multioutput
            format.

        'uniform_average' :
            Errors of all outputs are averaged with uniform weight.

    Returns
    -------
    loss : float or ndarray of floats
        A non-negative floating point value (the best value is 0.0), or an
        array of floating point values, one for each individual target.

    Raises
    ------
    ValueError
        If `y_true` or `y_pred` contains a negative value.

    Examples
    --------
    >>> from sklearn.metrics import root_mean_squared_log_error
    >>> y_true = [3, 5, 2.5, 7]
    >>> y_pred = [2.5, 5, 4, 8]
    >>> root_mean_squared_log_error(y_true, y_pred)
    0.199...
    """
    _y_type, y_true, y_pred, multioutput = _check_reg_targets(
        y_true, y_pred, multioutput
    )
    check_consistent_length(y_true, y_pred, sample_weight)

    # Negative targets or predictions are rejected before the log transform.
    contains_negative = any((values < 0).any() for values in (y_true, y_pred))
    if contains_negative:
        message = (
            "Root Mean Squared Logarithmic Error cannot be used when "
            "targets contain negative values."
        )
        raise ValueError(message)

    # RMSLE is the RMSE evaluated on log(1 + y) for both arrays.
    log_true = np.log1p(y_true)
    log_pred = np.log1p(y_pred)
    return root_mean_squared_error(
        log_true,
        log_pred,
        sample_weight=sample_weight,
        multioutput=multioutput,
    )
579
770
580
771
0 commit comments