@@ -90,7 +90,11 @@ def _check_reg_targets(y_true, y_pred, multioutput):
90
90
n_outputs = y_true .shape [1 ]
91
91
multioutput_options = (None , 'raw_values' , 'uniform_average' ,
92
92
'variance_weighted' )
93
- if multioutput not in multioutput_options :
93
+ if isinstance (multioutput , string_types ) and \
94
+ multioutput not in multioutput_options :
95
+ raise ValueError ("Invalid multioutput option" )
96
+
97
+ elif multioutput is not None and not isinstance (multioutput , string_types ):
94
98
multioutput = check_array (multioutput , ensure_2d = False )
95
99
if n_outputs == 1 :
96
100
raise ValueError ("Custom weights are useful only in "
@@ -505,7 +509,8 @@ def r2_score(y_true, y_pred,
505
509
0.948...
506
510
>>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
507
511
>>> y_pred = [[0, 2], [-1, 2], [8, -5]]
508
- >>> r2_score(y_true, y_pred, multioutput='variance_weighted') # doctest: +ELLIPSIS
512
+ # doctest: +ELLIPSIS
513
+ >>> r2_score(y_true, y_pred, multioutput='variance_weighted')
509
514
0.938...
510
515
>>> y_true = [1,2,3]
511
516
>>> y_pred = [1,2,3]
0 commit comments