8000 FIX don't compare things that can be arrays to strings. · scikit-learn/scikit-learn@550b228 · GitHub
[go: up one dir, main page]

Skip to content

Commit 550b228

Browse files
committed
FIX don't compare things that can be arrays to strings.
1 parent 48a2329 commit 550b228

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

sklearn/metrics/regression.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,15 @@ def explained_variance_score(y_true, y_pred,
352352
output_scores[valid_score] = 1 - (numerator[valid_score] /
353353
denominator[valid_score])
354354
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.
355-
if multioutput == 'raw_values':
356-
# return scores individually
357-
return output_scores
358-
elif multioutput == 'uniform_average':
359-
# passing to np.average() None as weights results is uniform mean
360-
avg_weights = None
361-
elif multioutput == 'variance_weighted':
362-
avg_weights = denominator
355+
if isinstance(multioutput, string_types):
356+
if multioutput == 'raw_values':
357+
# return scores individually
358+
return output_scores
359+
elif multioutput == 'uniform_average':
360+
# passing to np.average() None as weights results is uniform mean
361+
avg_weights = None
362+
elif multioutput == 'variance_weighted':
363+
avg_weights = denominator
363364
else:
364365
avg_weights = multioutput
365366

@@ -394,7 +395,7 @@ def r2_score(y_true, y_pred,
394395
Defines aggregating of multiple output scores.
395396
Array-like value defines weights used to average scores.
396397
Default value correponds to 'variance_weighted', this behaviour is
397-
deprecated since version 0.17 and will be changed to 'uniform_average'
398+
deprecated since version 0.17 and will be changed to 'uniform_average'
398399
starting from 0.19.
399400
400401
'raw_values' :

0 commit comments

Comments
 (0)
0