8000 FIX don't compare things that can be arrays to strings. · scikit-learn/scikit-learn@058e919 · GitHub
[go: up one dir, main page]

Skip to content

Commit 058e919

Browse files
committed
FIX don't compare things that can be arrays to strings.
1 parent 48a2329 commit 058e919

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

sklearn/metrics/regression.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def mean_absolute_error(y_true, y_pred,
169169
elif multioutput == 'uniform_average':
170170
# pass None as weights to np.average: uniform mean
171171
multioutput = None
172+
else:
173+
raise ValueError("Unknown multioutput type: %s" % multioutput)
172174

173175
return np.average(output_errors, weights=multioutput)
174176

@@ -237,6 +239,8 @@ def mean_squared_error(y_true, y_pred,
237239
elif multioutput == 'uniform_average':
238240
# pass None as weights to np.average: uniform mean
239241
multioutput = None
242+
else:
243+
raise ValueError("Unknown multioutput type: %s" % multioutput)
240244

241245
return np.average(output_errors, weights=multioutput)
242246

@@ -352,14 +356,17 @@ def explained_variance_score(y_true, y_pred,
352356
output_scores[valid_score] = 1 - (numerator[valid_score] /
353357
denominator[valid_score])
354358
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.
355-
if multioutput == 'raw_values':
356-
# return scores individually
357-
return output_scores
358-
elif multioutput == 'uniform_average':
359-
# passing to np.average() None as weights results is uniform mean
360-
avg_weights = None
361-
elif multioutput == 'variance_weighted':
362-
avg_weights = denominator
359+
if isinstance(multioutput, string_types):
360+
if multioutput == 'raw_values':
361+
# return scores individually
362+
return output_scores
363+
elif multioutput == 'uniform_average':
364+
# passing to np.average() None as weights results is uniform mean
365+
avg_weights = None
366+
elif multioutput == 'variance_weighted':
367+
avg_weights = denominator
368+
else:
369+
raise ValueError("Unknown multioutput type: %s" % multioutput)
363370
else:
364371
avg_weights = multioutput
365372

@@ -394,7 +401,7 @@ def r2_score(y_true, y_pred,
394401
Defines aggregating of multiple output scores.
395402
Array-like value defines weights used to average scores.
396403
Default value correponds to 'variance_weighted', this behaviour is
397-
deprecated since version 0.17 and will be changed to 'uniform_average'
404+
deprecated since version 0.17 and will be changed to 'uniform_average'
398405
starting from 0.19.
399406
400407
'raw_values' :
@@ -483,6 +490,8 @@ def r2_score(y_true, y_pred,
483490
return 1.0
484491
else:
485492
return 0.0
493+
else:
494+
raise ValueError("Unknown multioutput type: %s" % multioutput)
486495
else:
487496
avg_weights = multioutput
488497

0 commit comments

Comments
 (0)
0