8000 add test, fix example code · numpy/numpy@865d168 · GitHub
[go: up one dir, main page]

Skip to content

Commit 865d168

Browse files
committed
add test, fix example code
1 parent 27106b0 commit 865d168

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

doc/source/reference/c-api/data_memory.rst

Lines changed: 7 additions & 3 deletions
8000
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ What happens when deallocating if there is no policy set
124124
--------------------------------------------------------
125125
126126
A rare but useful technique is to allocate a buffer outside NumPy, use
127-
:c:function:`PyArray_NewFromDescr` to wrap the buffer in a ``ndarray``, then switch
127+
:c:func:`PyArray_NewFromDescr` to wrap the buffer in a ``ndarray``, then switch
128128
the ``OWNDATA`` flag to true. When the ``ndarray`` is released, the
129129
appropriate function from the ``ndarray``'s ``PyDataMem_Handler`` should be
130130
called to free the buffer. But the ``PyDataMem_Handler`` field was never set,
@@ -138,15 +138,19 @@ A better technique would be to use a ``PyCapsule`` as a base object:
138138
.. code-block:: c
139139
140140
/* define a PyCapsule_Destructor, using the correct deallocator for buff */
141-
void free_wrap(PyObject *obj){ free(obj); };
141+
void free_wrap(void *capsule){
142+
void * obj = PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule));
143+
free(obj);
144+
};
142145
143146
/* then inside the function that creates arr from buff */
144147
...
145148
arr = PyArray_NewFromDescr(... buf, ...);
146149
if (arr == NULL) {
147150
return NULL;
148151
}
149-
capsule = PyCapsule_New(buf, "my_wrapped_buffer", free_wrap);
152+
capsule = PyCapsule_New(buf, "my_wrapped_buffer",
153+
(PyCapsule_Destructor)&free_wrap);
150154
if (PyArray_SetBaseObject(arr, capsule) == -1) {
151155
Py_DECREF(arr);
152156
return NULL;

numpy/core/tests/test_mem_policy.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,29 @@ def get_module(tmp_path):
6262
// PyArray_BASE(PyArrayObject *)args) = NULL;
6363
Py_RETURN_NONE;
6464
"""),
65+
("get_array_with_base", "METH_NOARGS", """
66+
char *buf = (char *)malloc(20);
67+
npy_intp dims[1];
68+
dims[0] = 20;
69+
PyArray_Descr *descr = PyArray_DescrNewFromType(NPY_UINT8);
70+
PyObject *arr = PyArray_NewFromDescr(&PyArray_Type, descr, 1, dims,
71+
NULL, buf,
72+
NPY_ARRAY_WRITEABLE, NULL);
73+
if (arr == NULL) return NULL;
74+
PyObject *obj = PyCapsule_New(buf, "buf capsule",
75+
(PyCapsule_Destructor)&warn_on_free);
76+
if (obj == NULL) {
77+
Py_DECREF(arr);
78+
return NULL;
79+
}
80+
if (PyArray_SetBaseObject((PyArrayObject *)arr, obj) == -1) {
81+
Py_DECREF(arr);
82+
Py_DECREF(obj);
83+
return NULL;
84+
}
85+
return arr;
86+
87+
"""),
6588
]
6689
prologue = '''
6790
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -164,6 +187,12 @@ def get_module(tmp_path):
< 8000 code>164187
shift_free /* free */
165188
}
166189
};
190+
void warn_on_free(void *capsule) {
191+
PyErr_WarnEx(PyExc_UserWarning, "in warn_on_free", 1);
192+
void * obj = PyCapsule_GetPointer(capsule,
193+
PyCapsule_GetName(capsule));
194+
free(obj);
195+
};
167196
'''
168197
more_init = "import_array();"
169198
try:
@@ -335,15 +364,18 @@ def test_switch_owner(get_module):
335364
oldval = os.environ.get('NUMPY_WARN_IF_NO_MEM_POLICY', None)
336365
os.environ['NUMPY_WARN_IF_NO_MEM_POLICY'] = "1"
337366
try:
338-
with warnings.catch_warnings():
339-
warnings.filterwarnings('always')
340-
# The policy should be NULL, so we have to assume we can call
341-
# "free"
342-
with assert_warns(RuntimeWarning) as w:
343-
del a
344-
gc.collect()
367+
# The policy should be NULL, so we have to assume we can call
368+
# "free"
369+
with assert_warns(RuntimeWarning) as w:
370+
del a
371+
gc.collect()
345372
finally:
346373
if oldval is None:
347374
os.environ.pop('NUMPY_WARN_IF_NO_MEM_POLICY')
348375
else:
349376
os.environ['NUMPY_WARN_IF_NO_MEM_POLICY'] = oldval
377+
378+
a = get_module.get_array_with_base()
379+
with pytest.warns(UserWarning, match='warn_on_free'):
380+
del a
381+
gc.collect()

0 commit comments

Comments
 (0)
0