@@ -62,6 +62,29 @@ def get_module(tmp_path):
62
62
// PyArray_BASE(PyArrayObject *)args) = NULL;
63
63
Py_RETURN_NONE;
64
64
""" ),
65
+ ("get_array_with_base" , "METH_NOARGS" , """
66
+ char *buf = (char *)malloc(20);
67
+ npy_intp dims[1];
68
+ dims[0] = 20;
69
+ PyArray_Descr *descr = PyArray_DescrNewFromType(NPY_UINT8);
70
+ PyObject *arr = PyArray_NewFromDescr(&PyArray_Type, descr, 1, dims,
71
+ NULL, buf,
72
+ NPY_ARRAY_WRITEABLE, NULL);
73
+ if (arr == NULL) return NULL;
74
+ PyObject *obj = PyCapsule_New(buf, "buf capsule",
75
+ (PyCapsule_Destructor)&warn_on_free);
76
+ if (obj == NULL) {
77
+ Py_DECREF(arr);
78
+ return NULL;
79
+ }
80
+ if (PyArray_SetBaseObject((PyArrayObject *)arr, obj) == -1) {
81
+ Py_DECREF(arr);
82
+ Py_DECREF(obj);
83
+ return NULL;
84
+ }
85
+ return arr;
86
+
87
+ """ ),
65
88
]
66
89
prologue = '''
67
90
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -164,6 +187,12 @@ def get_module(tmp_path):
<
8000
code>164 187
shift_free /* free */
165
188
}
166
189
};
190
+ void warn_on_free(void *capsule) {
191
+ PyErr_WarnEx(PyExc_UserWarning, "in warn_on_free", 1);
192
+ void * obj = PyCapsule_GetPointer(capsule,
193
+ PyCapsule_GetName(capsule));
194
+ free(obj);
195
+ };
167
196
'''
168
197
more_init = "import_array();"
169
198
try :
@@ -335,15 +364,18 @@ def test_switch_owner(get_module):
335
364
oldval = os .environ .get ('NUMPY_WARN_IF_NO_MEM_POLICY' , None )
336
365
os .environ ['NUMPY_WARN_IF_NO_MEM_POLICY' ] = "1"
337
366
try :
338
- with warnings .catch_warnings ():
339
- warnings .filterwarnings ('always' )
340
- # The policy should be NULL, so we have to assume we can call
341
- # "free"
342
- with assert_warns (RuntimeWarning ) as w :
343
- del a
344
- gc .collect ()
367
+ # The policy should be NULL, so we have to assume we can call
368
+ # "free"
369
+ with assert_warns (RuntimeWarning ) as w :
370
+ del a
371
+ gc .collect ()
345
372
finally :
346
373
if oldval is None :
347
374
os .environ .pop ('NUMPY_WARN_IF_NO_MEM_POLICY' )
348
375
else :
349
376
os .environ ['NUMPY_WARN_IF_NO_MEM_POLICY' ] = oldval
377
+
378
+ a = get_module .get_array_with_base ()
379
+ with pytest .warns (UserWarning , match = 'warn_on_free' ):
380
+ del a
381
+ gc .collect ()
0 commit comments