8000 [MRG + 1] Fix failure on numpy master (#8011) · sergeyf/scikit-learn@27fa08e · GitHub
[go: up one dir, main page]

Skip to content

Commit 27fa08e

Browse files
aashilsergeyf
authored andcommitted
[MRG + 1] Fix failure on numpy master (scikit-learn#8011)
Was causing "ValueError: The truth value of an array with more than one element is ambiguous"
1 parent 172853d commit 27fa08e

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

sklearn/metrics/regression.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,15 @@ def _check_reg_targets(y_true, y_pred, multioutput):
8787
"({0}!={1})".format(y_true.shape[1], y_pred.shape[1]))
8888

8989
n_outputs = y_true.shape[1]
90-
multioutput_options = (None, 'raw_values', 'uniform_average',
91-
'variance_weighted')
92-
if multioutput not in multioutput_options:
90+
allowed_multioutput_str = ('raw_values', 'uniform_average',
91+
'variance_weighted')
92+
if isinstance(multioutput, string_types):
93+
if multioutput not in allowed_multioutput_str:
94+
raise ValueError("Allowed 'multioutput' string values are {}. "
95+
"You provided multioutput={!r}".format(
96+
allowed_multioutput_str,
97+
multioutput))
98+
elif multioutput is not None:
9399
multioutput = check_array(multioutput, ensure_2d=False)
94100
if n_outputs == 1:
95101
raise ValueError("Custom weights are useful only in "
@@ -504,7 +510,8 @@ def r2_score(y_true, y_pred, sample_weight=None,
504510
0.948...
505511
>>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
506512
>>> y_pred = [[0, 2], [-1, 2], [8, -5]]
507-
>>> r2_score(y_true, y_pred, multioutput='variance_weighted') # doctest: +ELLIPSIS
513+
>>> r2_score(y_true, y_pred, multioutput='variance_weighted')
514+
... # doctest: +ELLIPSIS
508515
0.938...
509516
>>> y_true = [1,2,3]
510517
>>> y_pred = [1,2,3]

sklearn/metrics/tests/test_regression.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ def test__check_reg_targets():
9393
assert_raises(ValueError, _check_reg_targets, y1, y2, None)
9494

9595

96+
def test__check_reg_targets_exception():
97+
invalid_multioutput = 'this_value_is_not_valid'
98+
expected_message = ("Allowed 'multioutput' string values are.+"
99+
"You provided multioutput={!r}".format(
100+
invalid_multioutput))
101+
assert_raises_regex(ValueError, expected_message,
102+
_check_reg_targets,
103+
[1, 2, 3],
104+
[[1], [2], [3]],
105+
invalid_multioutput)
106+
107+
96108
def test_regression_multioutput_array():
97109
y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
98110
y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]

0 commit comments

Comments
 (0)
0