8000 BUG: core: handle sub-arrays in dtype comparisons · numpy/numpy@5012504 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5012504

Browse files
mwiebepv
authored andcommitted
BUG: core: handle sub-arrays in dtype comparisons
1 parent 33b3e60 commit 5012504

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,37 @@ _equivalent_units(PyObject *meta1, PyObject *meta2)
13611361
&& (data1->events == data2->events));
13621362
}
13631363

1364+
/*
1365+
* Compare the subarray data for two types.
1366+
* Return 1 if they are the same, 0 if not.
1367+
*/
1368+
static int
1369+
_equivalent_subarrays(PyArray_ArrayDescr *sub1, PyArray_ArrayDescr *sub2)
1370+
{
1371+
int val;
1372+
1373+
if (sub1 == sub2) {
1374+
return 1;
1375+
1376+
}
1377+
if (sub1 == NULL || sub2 == NULL) {
1378+
return 0;
1379+
}
1380+
1381+
#if defined(NPY_PY3K)
1382+
val = PyObject_RichCompareBool(sub1->shape, sub2->shape, Py_EQ);
1383+
if (val != 1 || PyErr_Occurred()) {
1384+
#else
1385+
val = PyObject_Compare(sub1->shape, sub2->shape);
1386+
if (val != 0 || PyErr_Occurred()) {
1387+
#endif
1388+
PyErr_Clear();
1389+
return 0;
1390+
}
1391+
1392+
return PyArray_EquivTypes(sub1->base, sub2->base);
1393+
}
1394+
13641395

13651396
/*NUMPY_API
13661397
*
@@ -1381,6 +1412,10 @@ PyArray_EquivTypes(PyArray_Descr *typ1, PyArray_Descr *typ2)
13811412
if (PyArray_ISNBO(typ1->byteorder) != PyArray_ISNBO(typ2->byteorder)) {
13821413
return FALSE;
13831414
}
1415+
if (typ1->subarray || typ2->subarray) {
1416+
return ((typenum1 == typenum2)
1417+
&& _equivalent_subarrays(typ1->subarray, typ2->subarray));
1418+
}
13841419
if (typenum1 == PyArray_VOID
13851420
|| typenum2 == PyArray_VOID) {
13861421
return ((typenum1 == typenum2)
@@ -1874,7 +1909,7 @@ array_arange(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kws) {
18741909
}
18751910
range = PyArray_ArangeObj(o_start, o_stop, o_step, typecode);
18761911
Py_XDECREF(typecode);
1877-
return range;
1912+
return range;
18781913
}
18791914

18801915
/*NUMPY_API

numpy/core/tests/test_dtype.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,33 @@ def test_not_lists(self):
5555
self.assertRaises(TypeError, np.dtype,
5656
dict(names=['A', 'B'], formats=set(['f8', 'i4'])))
5757

58+
class TestShape(TestCase):
59+
def test_equal(self):
60+
"""Test some data types that are equal"""
61+
self.assertEqual(np.dtype('f8'), np.dtype(('f8',tuple())))
62+
self.assertEqual(np.dtype('f8'), np.dtype(('f8',1)))
63+
self.assertEqual(np.dtype((np.int,2)), np.dtype((np.int,(2,))))
64+
self.assertEqual(np.dtype(('<f4',(3,2))), np.dtype(('<f4',(3,2))))
65+
d = ([('a','f4',(1,2)),('b','f8',(3,1))],(3,2))
66+
self.assertEqual(np.dtype(d), np.dtype(d))
67+
68+
def test_simple(self):
69+
"""Test some simple cases that shouldn't be equal"""
70+
self.assertNotEqual(np.dtype('f8'), np.dtype(('f8',(1,))))
71+
self.assertNotEqual(np.dtype(('f8',(1,))), np.dtype(('f8',(1,1))))
72+
self.assertNotEqual(np.dtype(('f4',(3,2))), np.dtype(('f4',(2,3))))
73+
74+
def test_monster(self):
75+
"""Test some more complicated cases that shouldn't be equal"""
76+
self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
77+
np.dtype(([('a','f4',(1,2)), ('b','f8',(1,3))],(2,2))))
78+
self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
79+
np.dtype(([('a','f4',(2,1)), ('b','i8',(1,3))],(2,2))))
80+
self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
81+
np.dtype(([('e','f8',(1,3)), ('d','f4',(2,1))],(2,2))))
82+
self.assertNotEqual(np.dtype(([('a',[('a','i4',6)],(2,1)), ('b','f8',(1,3))],(2,2))),
83+
np.dtype(([('a',[('a','u4',6)],(2,1)), ('b','f8',(1,3))],(2,2))))
84+
5885
class TestSubarray(TestCase):
5986
def test_single_subarray(self):
6087
a = np.dtype((np.int, (2)))

0 commit comments

Comments
 (0)
0