8000 Merge pull request #12243 from liwt31/fix_misleading_msg · shoyer/numpy@7fbcc4e · GitHub
[go: up one dir, main page]

Skip to content

Commit 7fbcc4e

Browse files
authored
Merge pull request numpy#12243 from liwt31/fix_misleading_msg
BUG: Fix misleading assert message in assert_almost_equal numpy#12200
2 parents 2705bd5 + be5ea7d commit 7fbcc4e

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

numpy/testing/_private/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,9 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
692692
x = array(x, copy=False, subok=True)
693693
y = array(y, copy=False, subok=True)
694694

695+
# original array for output formating
696+
ox, oy = x, y
697+
695698
def isnumber(x):
696699
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
697700

@@ -785,10 +788,10 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
785788
# do not trigger a failure (np.ma.masked != True evaluates as
786789
# np.ma.masked, which is falsy).
787790
if cond != True:
788-
match = 100-100.0*reduced.count(1)/len(reduced)
789-
msg = build_err_msg([x, y],
791+
mismatch = 100.0 * reduced.count(0) / ox.size
792+
msg = build_err_msg([ox, oy],
790793
err_msg
791-
+ '\n(mismatch %s%%)' % (match,),
794+
+ '\n(mismatch %s%%)' % (mismatch,),
792795
verbose=verbose, header=header,
793796
names=('x', 'y'), precision=precision)
794797
raise AssertionError(msg)

numpy/testing/tests/test_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,8 @@ def test_complex(self):
507507
self._test_not_equal(x, z)
508508

509509
def test_error_message(self):
510-
"""Check the message is formatted correctly for the decimal value"""
510+
"""Check the message is formatted correctly for the decimal value.
511+
Also check the message when input includes inf or nan (gh12200)"""
511512
x = np.array([1.00000000001, 2.00000000002, 3.00003])
512513
y = np.array([1.00000000002, 2.00000000003, 3.00004])
513514

@@ -531,6 +532,19 @@ def test_error_message(self):
531532
# remove anything that's not the array string
532533
assert_equal(str(e).split('%)\n ')[1], b)
533534

535+
# Check the error message when input includes inf or nan
536+
x = np.array([np.inf, 0])
537+
y = np.array([np.inf, 1])
538+
try:
539+
self._assert_func(x, y)
540+
except AssertionError as e:
541+
msgs = str(e).split('\n')
542+
# assert error percentage is 50%
543+
assert_equal(msgs[3], '(mismatch 50.0%)')
544+
# assert output array contains inf
545+
assert_equal(msgs[4], ' x: array([inf, 0.])')
546+
assert_equal(msgs[5], ' y: array([inf, 1.])')
547+
534548
def test_subclass_that_cannot_be_bool(self):
535549
# While we cannot guarantee testing functions will always work for
536550
# subclasses, the tests should ideally rely only on subclasses having
@@ -1115,7 +1129,7 @@ def test_simple(self):
11151129

11161130
assert_raises(AssertionError,
11171131
lambda: assert_string_equal("foo", "hello"))
1118-
1132+
11191133
def test_regex(self):
11201134
assert_string_equal("a+*b", "a+*b")
11211135

0 commit comments

Comments
 (0)
0