8000 ENH: fix and test for blindly taking ownership of data · numpy/numpy@442b0e1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 442b0e1

Browse files
committed
ENH: fix and test for blindly taking ownership of data
1 parent 522c368 commit 442b0e1

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

numpy/core/src/multiarray/arrayobject.c

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,17 @@ array_dealloc(PyArrayObject *self)
501501
if (nbytes == 0) {
502502
nbytes = fa->descr->elsize ? fa->descr->elsize : 1;
503503
}
504-
PyDataMem_UserFREE(fa->data, nbytes, fa->mem_handler);
505-
Py_DECREF(fa->mem_handler);
504+
if (fa->mem_handler == NULL) {
505+
char const * msg = "Trying to dealloc data, but a memory policy "
506+
"is not set. If you take ownership of the data, you must "
507+
"also set a memory policy.";
508+
WARN_IN_DEALLOC(PyExc_RuntimeWarning, msg);
509+
// Guess at malloc/free ???
510+
free(fa->data);
511+
} else {
512+
PyDataMem_UserFREE(fa->data, nbytes, fa->mem_handler);
513+
Py_DECREF(fa->mem_handler);
514+
}
506515
}
507516

508517
/* must match allocation in PyArray_NewFromDescr */

numpy/core/tests/test_mem_policy.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import asyncio
2+
import gc
23
import pytest
34
import numpy as np
45
import threading
5-
from numpy.testing import extbuild
6+
import warnings
7+
from numpy.testing import extbuild, assert_warns
68
import sys
79

810

@@ -40,6 +42,25 @@ def get_module(tmp_path):
4042
Py_DECREF(old);
4143
Py_RETURN_NONE;
4244
"""),
45+
("get_array", "METH_NOARGS", """
46+
char *buf = (char *)malloc(20);
47+
npy_intp dims[1];
48+
dims[0] = 20;
49+
PyArray_Descr *descr = PyArray_DescrNewFromType(NPY_UINT8);
50+
return PyArray_NewFromDescr(&PyArray_Type, descr, 1, dims, NULL,
51+
buf, NPY_ARRAY_WRITEABLE, NULL);
52+
"""),
53+
("set_own", "METH_O", """
54+
if (!PyArray_Check(args)) {
55+
PyErr_SetString(PyExc_ValueError,
56+
"need an ndarray");
57+
return NULL;
58+
}
59+
PyArray_ENABLEFLAGS((PyArrayObject*)args, NPY_ARRAY_OWNDATA);
60+
// Maybe try this too?
61+
// PyArray_BASE(PyArrayObject *)args) = NULL;
62+
Py_RETURN_NONE;
63+
"""),
4364
]
4465
prologue = '''
4566
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -305,3 +326,15 @@ def test_new_policy(get_module):
305326

306327
c = np.arange(10)
307328
assert np.core.multiarray.get_handler_name(c) == orig_policy_name
329+
330+
def test_switch_owner(get_module):
331+
a = get_module.get_array()
332+
assert np.core.multiarray.get_handler_name(a) is None
333+
get_module.set_own(a)
334+
with warnings.catch_warnings():
335+
warnings.filterwarnings('always')
336+
# The policy should be NULL, so we have to assume we can call "free"
337+
with assert_warns(RuntimeWarning) as w:
338+
del a
339+
gc.collect()
340+
print(w)

0 commit comments

Comments
 (0)
0