From bfdd6c7c0be11bbe89f110d30034767550ce471f Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Thu, 28 Jan 2016 15:21:58 +0530 Subject: [PATCH] correct behaviour for assert_array_less with inf/nan --- numpy/testing/tests/test_utils.py | 43 +++++++++++++++++++++++++++++-- numpy/testing/utils.py | 11 ++------ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 7de57d40809c..25a9125b1eb5 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -6,12 +6,12 @@ import numpy as np from numpy.testing import ( - assert_equal, assert_array_equal, assert_almost_equal, + assert_equal, assert_array_equal, assert_almost_equal, assert_array_less, assert_array_almost_equal, build_err_msg, raises, assert_raises, assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp, clear_and_catch_warnings, run_module_suite, - assert_string_equal, assert_, tempdir, temppath, + assert_string_equal, assert_, tempdir, temppath, ) import unittest @@ -335,6 +335,45 @@ def test_error_message(self): self.assertEqual(str(e).split('%)\n ')[1], b) +class TestArrayLess(unittest.TestCase): + + def setUp(self): + self._assert_func = assert_array_less + + def test_simple(self): + a = 1 + b = [2, 3, 4] + self._assert_func(a, b) + a = [1, 0, 1] + self._assert_func(a, b) + a = [1] + self.assertRaises(AssertionError, + lambda: self._assert_func(a, b)) + + def test_nan(self): + a = [1, np.nan] + b = [2, np.nan] + self._assert_func(a, b) + a = [np.nan, 1] + self.assertRaises(AssertionError, + lambda: self._assert_func(a, b)) + + def test_inf(self): + a = 2 + b = np.inf + self._assert_func(a, b) + a = [ 0.911, 1.065, 1.325, 1.587] + self._assert_func(a, b) + a.append(np.inf) + self.assertRaises(AssertionError, + lambda: self._assert_func(a, b)) + self._assert_func(-np.inf, np.inf) + a = [1, np.inf] + b = [2, np.inf] + self.assertRaises(AssertionError, + lambda: self._assert_func(a, b)) + + class TestApproxEqual(unittest.TestCase): def setUp(self): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index f2588788df57..699ef5f2c0f9 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -689,20 +689,13 @@ def chk_same_position(x_id, y_id, hasval='nan'): if isnumber(x) and isnumber(y): x_isnan, y_isnan = isnan(x), isnan(y) - x_isinf, y_isinf = isinf(x), isinf(y) # Validate that the special values are in the same place if any(x_isnan) or any(y_isnan): chk_same_position(x_isnan, y_isnan, hasval='nan') - if any(x_isinf) or any(y_isinf): - # Check +inf and -inf separately, since they are different - chk_same_position(x == +inf, y == +inf, hasval='+inf') - chk_same_position(x == -inf, y == -inf, hasval='-inf') # Combine all the special values x_id, y_id = x_isnan, y_isnan - x_id |= x_isinf - y_id |= y_isinf # Only do the comparison if actual values are left if all(x_id): @@ -886,7 +879,7 @@ def compare(x, y): if npany(gisinf(x)) or npany( gisinf(y)): xinfid = gisinf(x) yinfid = gisinf(y) - if not xinfid == yinfid: + if not (xinfid == yinfid).all(): return False # if one item, x and y is +- inf if x.size == y.size == 1: @@ -1863,7 +1856,7 @@ def temppath(*args, **kwargs): parameters are the same as for tempfile.mkstemp and are passed directly to that function. The underlying file is removed when the context is exited, so it should be closed at that time. - + Windows does not allow a temporary file to be opened if it is already open, so the underlying file must be closed after opening before it can be opened again.