10000 MNT: convert runtime imports to use single-initialization · numpy/numpy@f587879 · GitHub
[go: up one dir, main page]

Skip to content

Commit f587879

Browse files
committed
MNT: convert runtime imports to use single-initialization
1 parent 2581b5b commit f587879

20 files changed

+204
-153
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,7 @@ src_multiarray_umath_common = [
10381038
'src/common/mem_overlap.c',
10391039
'src/common/npy_argparse.c',
10401040
'src/common/npy_hashtable.c',
1041+
'src/common/npy_import.c',
10411042
'src/common/npy_longdouble.c',
10421043
'src/common/ucsnarrow.c',
10431044
'src/common/ufunc_override.c',

numpy/_core/src/common/npy_ctypes.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ npy_ctypes_check(PyTypeObject *obj)
2121
PyObject *ret_obj;
2222
int ret;
2323

24-
npy_cache_import("numpy._core._internal", "npy_ctypes_check",
25-
&npy_thread_unsafe_state.npy_ctypes_check);
26-
if (npy_thread_unsafe_state.npy_ctypes_check == NULL) {
24+
if (npy_cache_import_runtime(
25+
"numpy._core._internal", "npy_ctypes_check",
26+
&npy_runtime_imports.npy_ctypes_check) == -1) {
2727
goto fail;
2828
}
2929

30-
ret_obj = PyObject_CallFunctionObjArgs(npy_thread_unsafe_state.npy_ctypes_check,
31-
(PyObject *)obj, NULL);
30+
ret_obj = PyObject_CallFunctionObjArgs(
31+
npy_runtime_imports.npy_ctypes_check.obj, (PyObject *)obj, NULL);
3232
if (ret_obj == NULL) {
3333
goto fail;
3434
}

numpy/_core/src/common/npy_import.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
2+
#define _MULTIARRAYMODULE
3+
4+
#include "numpy/ndarraytypes.h"
5+
#include "npy_import.h"
6+
#include "npy_atomic.h"
7+
8+
9+
NPY_VISIBILITY_HIDDEN npy_runtime_imports_struct npy_runtime_imports;
10+
11+
NPY_NO_EXPORT int
12+
init_import_mutex(void) {
13+
npy_runtime_imports.import_mutex = PyThread_allocate_lock();
14+
if (npy_runtime_imports.import_mutex == NULL) {
15+
PyErr_NoMemory();
16+
return -1;
17+
}
18+
return 0;
19+
}

numpy/_core/src/common/npy_import.h

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,80 @@
33

44
#include <Python.h>
55

6-
/*! \brief Fetch and cache Python function.
6+
#include "numpy/npy_common.h"
7+
#include "npy_atomic.h"
8+
9+
/*
10+
* Holds a cached PyObject where the cache is initialized via a
11+
* runtime import. The cache is only filled once.
12+
*/
13+
14+
typedef struct npy_runtime_import {
15+
npy_uint8 initialized;
16+
PyObject *obj;
17+
} npy_runtime_import;
18+
19+
/*
20+
* Cached references to objects obtained via an import. All of these are
21+
* can be initialized at any time by npy_cache_import_runtime.
22+
*/
23+
typedef struct npy_runtime_imports_struct {
24+
PyThread_type_lock import_mutex;
25+
npy_runtime_import _add_dtype_helper;
26+
npy_runtime_import _all;
27+
npy_runtime_import _amax;
28+
npy_runtime_import _amin;
29+
npy_runtime_import _any;
30+
npy_runtime_import array_function_errmsg_formatter;
31+
npy_runtime_import array_ufunc_errmsg_formatter;
32+
npy_runtime_import _clip;
33+
npy_runtime_import _commastring;
34+
npy_runtime_import _convert_to_stringdtype_kwargs;
35+
npy_runtime_import _default_array_repr;
36+
npy_runtime_import _default_array_str;
37+
npy_runtime_import _dump;
38+
npy_runtime_import _dumps;
39+
npy_runtime_import _getfield_is_safe;
40+
npy_runtime_import internal_gcd_func;
41+
npy_runtime_import _mean;
42+
npy_runtime_import NO_NEP50_WARNING;
43+
npy_runtime_import npy_ctypes_check;
44+
npy_runtime_import numpy_matrix;
45+
npy_runtime_import _prod;
46+
npy_runtime_import _promote_fields;
47+
npy_runtime_import _std;
48+
npy_runtime_import _sum;
49+
npy_runtime_import _ufunc_doc_signature_formatter;
50+
npy_runtime_import _var;
51+
npy_runtime_import _view_is_safe;
52+
npy_runtime_import _void_scalar_to_string;
53+
} npy_runtime_imports_struct;
54+
55+
NPY_VISIBILITY_HIDDEN extern npy_runtime_imports_struct npy_runtime_imports;
56+
57+
/*! \brief Import a Python object.
58+
59+
* This function imports the Python function specified by
60+
* \a module and \a function, increments its reference count, and returns
61+
* the result. On error, returns NULL.
62+
*
63+
* @param module Absolute module name.
64+
* @param attr module attribute to cache.
65+
*/
66+
static inline PyObject*
67+
npy_import(const char *module, const char *attr)
68+
{
69+
PyObject *ret = NULL;
70+
PyObject *mod = PyImport_ImportModule(module);
71+
72+
if (mod != NULL) {
73+
ret = PyObject_GetAttrString(mod, attr);
74+
Py_DECREF(mod);
75+
}
76+
return ret;
77+
}
78+
79+
/*! \brief Fetch and cache Python object at runtime.
780
*
881
* Import a Python function and cache it for use. The function checks if
982
* cache is NULL, and if not NULL imports the Python function specified by
@@ -16,17 +89,28 @@
1689
* @param attr module attribute to cache.
1790
* @param cache Storage location for imported function.
1891
*/
19-
static inline void
20-
npy_cache_import(const char *module, const char *attr, PyObject **cache)
21-
{
22-
if (NPY_UNLIKELY(*cache == NULL)) {
23-
PyObject *mod = PyImport_ImportModule(module);
24-
25-
if (mod != NULL) {
26-
*cache = PyObject_GetAttrString(mod, attr);
27-
Py_DECREF(mod);
92+
static inline int
93+
npy_cache_import_runtime(const char *module, const char *attr, npy_runtime_import *cache) {
94+
if (cache->initialized) {
95+
return 0;
96+
}
97+
else {
98+
if (!npy_atomic_load_uint8(&cache->initialized)) {
99+
PyThread_acquire_lock(npy_runtime_imports.import_mutex, WAIT_LOCK);
100+
if (!cache->initialized) {
101+
cache->obj = npy_import(module, attr);
102+
cache->initialized = 1;
103+
}
104+
PyThread_release_lock(npy_runtime_imports.import_mutex);
28105
}
29106
}
107+
if (cache->obj == NULL) {
108+
return -1;
109+
}
110+
return 0;
30111
}
31112

113+
NPY_NO_EXPORT int
114+
init_import_mutex(void);
115+
32116
#endif /* NUMPY_CORE_SRC_COMMON_NPY_IMPORT_H_ */

numpy/_core/src/multiarray/arrayfunction_override.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,12 @@ static void
232232
set_no_matching_types_error(PyObject *public_api, PyObject *types)
233233
{
234234
/* No acceptable override found, raise TypeError. */
235-
npy_cache_import("numpy._core._internal",
236-
"array_function_errmsg_formatter",
237-
&npy_thread_unsafe_state.array_function_errmsg_formatter);
238-
if (npy_thread_unsafe_state.array_function_errmsg_formatter != NULL) {
235+
if (npy_cache_import_runtime(
236+
"numpy._core._internal",
237+
"array_function_errmsg_formatter",
238+
&npy_runtime_imports.array_function_errmsg_formatter) == 0) {
239239
PyObject *errmsg = PyObject_CallFunctionObjArgs(
240-
npy_thread_unsafe_state.array_function_errmsg_formatter,
240+
npy_runtime_imports.array_function_errmsg_formatter.obj,
241241
public_api, types, NULL);
242242
if (errmsg != NULL) {
243243
PyErr_SetObject(PyExc_TypeError, errmsg);

numpy/_core/src/multiarray/convert_datatype.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,14 @@ npy_give_promotion_warnings(void)
8383
{
8484
PyObject *val;
8585

86-
npy_cache_import(
86+
if (npy_cache_import_runtime(
8787
"numpy._core._ufunc_config", "NO_NEP50_WARNING",
88-
&npy_thread_unsafe_state.NO_NEP50_WARNING);
89-
if (npy_thread_unsafe_state.NO_NEP50_WARNING == NULL) {
88+
&npy_runtime_imports.NO_NEP50_WARNING) == -1) {
9089
PyErr_WriteUnraisable(NULL);
9190
return 1;
9291
}
9392

94-
if (PyContextVar_Get(npy_thread_unsafe_state.NO_NEP50_WARNING,
93+
if (PyContextVar_Get(npy_runtime_imports.NO_NEP50_WARNING.obj,
9594
Py_False, &val) < 0) {
9695
/* Errors should not really happen, but if it does assume we warn. */
9796
PyErr_WriteUnraisable(NULL);

numpy/_core/src/multiarray/descriptor.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -726,12 +726,12 @@ _convert_from_commastring(PyObject *obj, int align)
726726
PyObject *parsed;
727727
PyArray_Descr *res;
728728
assert(PyUnicode_Check(obj));
729-
npy_cache_import("numpy._core._internal", "_commastring",
730-
&npy_thread_unsafe_state._commastring);
731-
if (npy_thread_unsafe_state._commastring == NULL) {
729+
if (npy_cache_import_runtime(
730+
"numpy._core._internal", "_commastring",
731+
&npy_runtime_imports._commastring) == -1) {
732732
return NULL;
733733
}
734-
parsed = PyObject_CallOneArg(npy_thread_unsafe_state._commastring, obj);
734+
parsed = PyObject_CallOneArg(npy_runtime_imports._commastring.obj, obj);
735735
if (parsed == NULL) {
736736
return NULL;
737737
}

numpy/_core/src/multiarray/dtypemeta.c

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -766,13 +766,13 @@ void_common_instance(_PyArray_LegacyDescr *descr1, _PyArray_LegacyDescr *descr2)
766766

767767
if (descr1->names != NULL && descr2->names != NULL) {
768768
/* If both have fields promoting individual fields may be possible */
769-
npy_cache_import("numpy._core._internal", "_promote_fields",
770-
&npy_thread_unsafe_state._promote_fields);
771-
if (npy_thread_unsafe_state._promote_fields == NULL) {
769+
if (npy_cache_import_runtime(
770+
"numpy._core._internal", "_promote_fields",
771+
&npy_runtime_imports._promote_fields) == -1) {
772772
return NULL;
773773
}
774774
PyObject *result = PyObject_CallFunctionObjArgs(
775-
npy_thread_unsafe_state._promote_fields,
775+
npy_runtime_imports._promote_fields.obj,
776776
descr1, descr2, NULL);
777777
if (result == NULL) {
778778
return NULL;
@@ -1240,14 +1240,13 @@ dtypemeta_wrap_legacy_descriptor(
12401240

12411241
/* And it to the types submodule if it is a builtin dtype */
12421242
if (!PyTypeNum_ISUSERDEF(descr->type_num)) {
1243-
npy_cache_import("numpy.dtypes", "_add_dtype_helper",
1244-
&npy_thread_unsafe_state._add_dtype_helper);
1245-
if (npy_thread_unsafe_state._add_dtype_helper == NULL) {
1243+
if (npy_cache_import_runtime("numpy.dtypes", "_add_dtype_helper",
1244+
&npy_runtime_imports._add_dtype_helper) == -1) {
12461245
return -1;
12471246
}
12481247

12491248
if (PyObject_CallFunction(
1250-
npy_thread_unsafe_state._add_dtype_helper,
1249+
npy_runtime_imports._add_dtype_helper.obj,
12511250
"Os", (PyObject *)dtype_class, alias) == NULL) {
12521251
return -1;
12531252
}

numpy/_core/src/multiarray/getset.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,13 @@ array_descr_set(PyArrayObject *self, PyObject *arg, void *NPY_UNUSED(ignored))
388388
if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(newtype)) {
389389
PyObject *safe;
390390

391-
npy_cache_import("numpy._core._internal", "_view_is_safe",
392-
&npy_thread_unsafe_state._view_is_safe);
393-
if (npy_thread_unsafe_state._view_is_safe == NULL) {
391+
if (npy_cache_import_runtime(
392+
"numpy._core._internal", "_view_is_safe",
393+
&npy_runtime_imports._view_is_safe) == -1) {
394394
goto fail;
395395
}
396396

397-
safe = PyObject_CallFunction(npy_thread_unsafe_state._view_is_safe,
397+
safe = PyObject_CallFunction(npy_runtime_imports._view_is_safe.obj,
398398
"OO", PyArray_DESCR(self), newtype);
399399
if (safe == NULL) {
400400
goto fail;

numpy/_core/src/multiarray/methods.c

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,11 @@ npy_forward_method(
113113
* be correct.
114114
*/
115115
#define NPY_FORWARD_NDARRAY_METHOD(name) \
116-
npy_cache_import( \
117-
"numpy._core._methods", #name, \
118-
&npy_thread_unsafe_state.name); \
119-
if (npy_thread_unsafe_state.name == NULL) { \
116+
if (npy_cache_import_runtime("numpy._core._methods", #name, \
117+
&npy_runtime_imports.name) == -1) { \
120118
return NULL; \
121119
} \
122-
return npy_forward_method(npy_thread_unsafe_state.name, \
120+
return npy_forward_method(npy_runtime_imports.name.obj, \
123121
(PyObject *)self, args, len_args, kwnames)
124122

125123

@@ -406,15 +404,15 @@ PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset)
406404

407405
/* check that we are not reinterpreting memory containing Objects. */
408406
if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(typed)) {
409-
npy_cache_import("numpy._core._internal", "_getfield_is_safe",
410-
&npy_thread_unsafe_state._getfield_is_safe);
411-
if (npy_thread_unsafe_state._getfield_is_safe == NULL) {
407+
if (npy_cache_import_runtime(
408+
"numpy._core._internal", "_getfield_is_safe",
409+
&npy_runtime_imports._getfield_is_safe) == -1) {
412410
Py_DECREF(typed);
413411
return NULL;
414412
}
415413

416414
/* only returns True or raises */
417-
safe = PyObject_CallFunction(npy_thread_unsafe_state._getfield_is_safe,
415+
safe = PyObject_CallFunction(npy_runtime_imports._getfield_is_safe.obj,
418416
"OOi", PyArray_DESCR(self),
419417
typed, offset);
420418
if (safe == NULL) {
@@ -2248,18 +2246,19 @@ NPY_NO_EXPORT int
22482246
PyArray_Dump(PyObject *self, PyObject *file, int protocol)
22492247
{
22502248
PyObject *ret;
2251-
npy_cache_import("numpy._core._methods", "_dump",
2252-
&npy_thread_unsafe_state._dump);
2253-
if (npy_thread_unsafe_state._dump == NULL) {
2249+
if (npy_cache_import_runtime(
2250+
"numpy._core._methods", "_dump",
2251+
&npy_runtime_imports._dump) == -1) {
22542252
return -1;
22552253
}
2254+
22562255
if (protocol < 0) {
22572256
ret = PyObject_CallFunction(
2258-
npy_thread_unsafe_state._dump, "OO", self, file);
2257+
npy_runtime_imports._dump.obj, "OO", self, file);
22592258
}
22602259
else {
22612260
ret = PyObject_CallFunction(
2262-
npy_thread_unsafe_state._dump, "OOi", self, file, protocol);
2261+
npy_runtime_imports._dump.obj, "OOi", self, file, protocol);
22632262
}
22642263
if (ret == NULL) {
22652264
return -1;
@@ -2272,17 +2271,16 @@ PyArray_Dump(PyObject *self, PyObject *file, int protocol)
22722271
NPY_NO_EXPORT PyObject *
22732272
PyArray_Dumps(PyObject *self, int protocol)
22742273
{
2275-
npy_cache_import("numpy._core._methods", "_dumps",
2276-
&npy_thread_unsafe_state._dumps);
2277-
if (npy_thread_unsafe_state._dumps == NULL) {
2274+
if (npy_cache_import_runtime("numpy._core._methods", "_dumps",
2275+
&npy_runtime_imports._dumps) == -1) {
22782276
return NULL;
22792277
}
22802278
if (protocol < 0) {
2281-
return PyObject_CallFunction(npy_thread_unsafe_state._dumps, "O", self);
2279+
return PyObject_CallFunction(npy_runtime_imports._dumps.obj, "O", self);
22822280
}
22832281
else {
22842282
return PyObject_CallFunction(
2285-
npy_thread_unsafe_state._dumps, "Oi", self, protocol);
2283+
npy_runtime_imports._dumps.obj, "Oi", self, protocol);
22862284
}
22872285
}
22882286

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4842,6 +4842,10 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
48424842
goto err;
48434843
}
48444844

4845+
if (init_import_mutex() < 0) {
4846+
goto err;
4847+
}
4848+
48454849
if (init_extobj() < 0) {
48464850
goto err;
48474851
}
@@ -5067,14 +5071,15 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
50675071
* init_string_dtype() but that needs to happen after
50685072
* the legacy dtypemeta classes are available.
50695073
*/
5070-
npy_cache_import("numpy.dtypes", "_add_dtype_helper",
5071-
&npy_thread_unsafe_state._add_dtype_helper);
5072-
if (npy_thread_unsafe_state._add_dtype_helper == NULL) {
5074+
5075+
if (npy_cache_import_runtime(
5076+
"numpy.dtypes", "_add_dtype_helper",
5077+
&npy_runtime_imports._add_dtype_helper) == -1) {
50735078
goto err;
50745079
}
50755080

50765081
if (PyObject_CallFunction(
5077-
npy_thread_unsafe_state._add_dtype_helper,
5082+
npy_runtime_imports._add_dtype_helper.obj,
50785083
"Os", (PyObject *)&PyArray_StringDType, NULL) == NULL) {
50795084
goto err;
50805085
}

0 commit comments

Comments
 (0)
0