8000 BUG: Update base-traversal algorithm for array sub-classes so as to s… · numpy/numpy@a9c7b38 · GitHub
[go: up one dir, main page]

Skip to content

Commit a9c7b38

Browse files
teoliphantcertik
authored andcommitted
BUG: Update base-traversal algorithm for array sub-classes so as to stop the base-traversal when the new base would not be an instance of the sub-class.
1 parent 5ae12ea commit a9c7b38

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

numpy/core/src/multiarray/arrayobject.c

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,29 +143,38 @@ PyArray_SetBaseObject(PyArrayObject *arr, PyObject *obj)
143143
}
144144

145145
/*
146-
* Don't allow chains of views, always set the base
147-
* to the owner of the data. That is, either the first object
148-
* which isn't an array, or the first object which owns
149-
* its own data.
146+
* Don't allow infinite chains of views, always set the base
147+
* to the first owner of the data.
148+
* That is, either the first object which isn't an array,
149+
* or the first object which owns its own data.
150150
*/
151+
151152
while (PyArray_Check(obj) && (PyObject *)arr != obj) {
152153
PyArrayObject *obj_arr = (PyArrayObject *)obj;
153154
PyObject *tmp;
154155

155156
/* Propagate WARN_ON_WRITE through views. */
156157
if (PyArray_FLAGS(obj_arr) & NPY_ARRAY_WARN_ON_WRITE) {
157158
PyArray_ENABLEFLAGS(arr, NPY_ARRAY_WARN_ON_WRITE);
158-
}
159+
}
159160

160161
/* If this array owns its own data, stop collapsing */
161162
if (PyArray_CHKFLAGS(obj_arr, NPY_ARRAY_OWNDATA)) {
162163
break;
163-
}
164-
/* If there's no base, stop collapsing */
164+
}
165+
165166
tmp = PyArray_BASE(obj_arr);
167+
/* If there's no base, stop collapsing */
166168
if (tmp == NULL) {
167169
break;
168170
}
171+
/* Stop the collapse for array sub-classes if new base
172+
* would not be of the same type.
173+
*/
174+
if (!(PyArray_CheckExact(arr)) & (Py_TYPE(tmp) != Py_TYPE(arr))) {
175+
break;
176+
}
177+
169178

170179
Py_INCREF(tmp);
171180
Py_DECREF(obj);

numpy/core/tests/test_memmap.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,12 @@ def test_slicing_keeps_references(self):
104104
shape=self.shape)
105105
assert fp[:2, :2]._mmap is fp._mmap
106106

107+
def test_view(self):
108+
fp = memmap(self.tmpfp, dtype=self.dtype, shape=self.shape)
109+
new1 = fp.view()
110+
new2 = new1.view()
111+
assert(new1.base is fp)
112+
assert(new2.base is fp)
113+
107114
if __name__ == "__main__":
108115
run_module_suite()

0 commit comments

Comments
 (0)
0