|
1 | 1 | import asyncio
|
| 2 | +import gc |
2 | 3 | import pytest
|
3 | 4 | import numpy as np
|
4 | 5 | import threading
|
5 |
| -from numpy.testing import extbuild |
| 6 | +import warnings |
| 7 | +from numpy.testing import extbuild, assert_warns |
6 | 8 | import sys
|
7 | 9 |
|
8 | 10 |
|
@@ -40,6 +42,25 @@ def get_module(tmp_path):
|
40 | 42 | Py_DECREF(old);
|
41 | 43 | Py_RETURN_NONE;
|
42 | 44 | """),
|
| 45 | + ("get_array", "METH_NOARGS", """ |
| 46 | + char *buf = (char *)malloc(20); |
| 47 | + npy_intp dims[1]; |
| 48 | + dims[0] = 20; |
| 49 | + PyArray_Descr *descr = PyArray_DescrNewFromType(NPY_UINT8); |
| 50 | + return PyArray_NewFromDescr(&PyArray_Type, descr, 1, dims, NULL, |
| 51 | + buf, NPY_ARRAY_WRITEABLE, NULL); |
| 52 | + """), |
| 53 | + ("set_own", "METH_O", """ |
| 54 | + if (!PyArray_Check(args)) { |
| 55 | + PyErr_SetString(PyExc_ValueError, |
| 56 | + "need an ndarray"); |
| 57 | + return NULL; |
| 58 | + } |
| 59 | + PyArray_ENABLEFLAGS((PyArrayObject*)args, NPY_ARRAY_OWNDATA); |
| 60 | + // Maybe try this too? |
| 61 | + // PyArray_BASE(PyArrayObject *)args) = NULL; |
| 62 | + Py_RETURN_NONE; |
| 63 | + """), |
43 | 64 | ]
|
44 | 65 | prologue = '''
|
45 | 66 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
@@ -305,3 +326,15 @@ def test_new_policy(get_module):
|
305 | 326 |
|
306 | 327 | c = np.arange(10)
|
307 | 328 | assert np.core.multiarray.get_handler_name(c) == orig_policy_name
|
| 329 | + |
| 330 | +def test_switch_owner(get_module): |
| 331 | + a = get_module.get_array() |
| 332 | + assert np.core.multiarray.get_handler_name(a) is None |
| 333 | + get_module.set_own(a) |
| 334 | + with warnings.catch_warnings(): |
| 335 | + warnings.filterwarnings('always') |
| 336 | + # The policy should be NULL, so we have to assume we can call "free" |
| 337 | + with assert_warns(RuntimeWarning) as w: |
| 338 | + del a |
| 339 | + gc.collect() |
| 340 | + print(w) |
0 commit comments