@@ -87,9 +87,15 @@ def _check_reg_targets(y_true, y_pred, multioutput):
87
87
"({0}!={1})" .format (y_true .shape [1 ], y_pred .shape [1 ]))
88
88
89
89
n_outputs = y_true .shape [1 ]
90
- multioutput_options = (None , 'raw_values' , 'uniform_average' ,
91
- 'variance_weighted' )
92
- if multioutput not in multioutput_options :
90
+ allowed_multioutput_str = ('raw_values' , 'uniform_average' ,
91
+ 'variance_weighted' )
92
+ if isinstance (multioutput , string_types ):
93
+ if multioutput not in allowed_multioutput_str :
94
+ raise ValueError ("Allowed 'multioutput' string values are {}. "
95
+ "You provided multioutput={!r}" .format (
96
+ allowed_multioutput_str ,
97
+ multioutput ))
98
+ elif multioutput is not None :
93
99
multioutput = check_array (multioutput , ensure_2d = False )
94
100
if n_outputs == 1 :
95
101
raise ValueError ("Custom weights are useful only in "
@@ -504,7 +510,8 @@ def r2_score(y_true, y_pred, sample_weight=None,
504
510
0.948...
505
511
>>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
506
512
>>> y_pred = [[0, 2], [-1, 2], [8, -5]]
507
- >>> r2_score(y_true, y_pred, multioutput='variance_weighted') # doctest: +ELLIPSIS
513
+ >>> r2_score(y_true, y_pred, multioutput='variance_weighted')
514
+ ... # doctest: +ELLIPSIS
508
515
0.938...
509
516
>>> y_true = [1,2,3]
510
517
>>> y_pred = [1,2,3]
0 commit comments