MAINT: back printoptions with a true context variable (#26846) · numpy/numpy@b3feb3c · GitHub
[go: up one dir, main page]

Skip to content

Commit b3feb3c

Browse files
ngoldbaum and mtsokol authored
MAINT: back printoptions with a true context variable (#26846)
This is a re-do of gh-26345; I'm taking over from @mtsokol because this is needed for the free-threaded work. The new printoptions.py file exists to avoid a circular import during setup of the multiarray module. I'm guessing this adds some overhead to printing. I haven't benchmarked it because it wasn't clear to me: do we care about printing performance? I could certainly add some caching or a way to avoid repeatedly calling get_legacy_print_mode for every printed value. We could also keep the C global we had before but make it thread-local. I just thought it made things conceptually simpler to store all the printoptions state in the context variable. Co-authored-by: Mateusz Sokół <mat646@gmail.com>
1 parent 9f66869 commit b3feb3c

File tree

12 files changed

+299
-98
lines changed

12 files changed

+299
-98
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
The `numpy.printoptions` context manager is now thread and async-safe
2+
---------------------------------------------------------------------
3+
4+
In prior versions of NumPy, the printoptions were defined using a combination
5+
of Python and C global variables. We have refactored so the state is stored in
6+
a Python ``ContextVar``, making the context manager thread and async-safe.

numpy/_core/arrayprint.py

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,17 @@
3535
from .umath import absolute, isinf, isfinite, isnat
3636
from . import multiarray
3737
from .multiarray import (array, dragon4_positional, dragon4_scientific,
38-
datetime_as_string, datetime_data, ndarray,
39-
set_legacy_print_mode)
38+
datetime_as_string, datetime_data, ndarray)
4039
from .fromnumeric import any
4140
from .numeric import concatenate, asarray, errstate
4241
from .numerictypes import (longlong, intc, int_, float64, complex128,
4342
flexible)
4443
from .overrides import array_function_dispatch, set_module
44+
from .printoptions import format_options
4545
import operator
4646
import warnings
4747
import contextlib
4848

49-
_format_options = {
50-
'edgeitems': 3, # repr N leading and trailing items of each dimension
51-
'threshold': 1000, # total items > triggers array summarization
52-
'floatmode': 'maxprec',
53-
'precision': 8, # precision of floating point representations
54-
'suppress': False, # suppress printing small floating values in exp format
55-
'linewidth': 75,
56-
'nanstr': 'nan',
57-
'infstr': 'inf',
58-
'sign': '-',
59-
'formatter': None,
60-
# Internally stored as an int to simplify comparisons; converted from/to
61-
# str/False on the way in/out.
62-
'legacy': sys.maxsize,
63-
'override_repr': None,
64-
}
6549

6650
def _make_options_dict(precision=None, threshold=None, edgeitems=None,
6751
linewidth=None, suppress=None, nanstr=None, infstr=None,
@@ -295,25 +279,21 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None,
295279
array([ 0. , 1.11, 2.22, ..., 7.78, 8.89, 10. ])
296280
297281
"""
298-
opt = _make_options_dict(precision, threshold, edgeitems, linewidth,
299-
suppress, nanstr, infstr, sign, formatter,
300-
floatmode, legacy, override_repr)
282+
new_opt = _make_options_dict(precision, threshold, edgeitems, linewidth,
283+
suppress, nanstr, infstr, sign, formatter,
284+
floatmode, legacy)
301285
# formatter and override_repr are always reset
302-
opt['formatter'] = formatter
303-
opt['override_repr'] = override_repr
304-
_format_options.update(opt)
305-
306-
# set the C variable for legacy mode
307-
if _format_options['legacy'] == 113:
308-
set_legacy_print_mode(113)
309-
# reset the sign option in legacy mode to avoid confusion
310-
_format_options['sign'] = '-'
311-
elif _format_options['legacy'] == 121:
312-
set_legacy_print_mode(121)
313-
elif _format_options['legacy'] == 125:
314-
set_legacy_print_mode(125)
315-
elif _format_options['legacy'] == sys.maxsize:
316-
set_legacy_print_mode(0)
286+
new_opt['formatter'] = formatter
287+
new_opt['override_repr'] = override_repr
288+
289+
updated_opt = format_options.get() | new_opt
290+
updated_opt.update(new_opt)
291+
292+
if updated_opt['legacy'] == 113:
293+
updated_opt['sign'] = '-'
294+
295+
token = format_options.set(updated_opt)
296+
return token
317297

318298

319299
@set_module('numpy')
@@ -355,7 +335,7 @@ def get_printoptions():
355335
100
356336
357337
"""
358-
opts = _format_options.copy()
338+
opts = format_options.get().copy()
359339
opts['legacy'] = {
360340
113: '1.13', 121: '1.21', 125: '1.25', sys.maxsize: False,
361341
}[opts['legacy']]
@@ -364,7 +344,7 @@ def get_printoptions():
364344

365345
def _get_legacy_print_mode():
366346
"""Return the legacy print mode as an int."""
367-
return _format_options['legacy']
347+
return format_options.get()['legacy']
368348

369349

370350
@set_module('numpy')
@@ -393,13 +373,12 @@ def printoptions(*args, **kwargs):
393373
--------
394374
set_printoptions, get_printoptions
395375
396-
"""
397-
opts = np.get_printoptions()
376+
"""
377+
token = set_printoptions(*args, **kwargs)
398378
try:
399-
np.set_printoptions(*args, **kwargs)
400-
yield np.get_printoptions()
379+
yield get_printoptions()
401380
finally:
402-
np.set_printoptions(**opts)
381+
format_options.reset(token)
403382

404383

405384
def _leading_trailing(a, edgeitems, index=()):
@@ -757,7 +736,7 @@ def array2string(a, max_line_width=None, precision=None,
757736
overrides = _make_options_dict(precision, threshold, edgeitems,
758737
max_line_width, suppress_small, None, None,
759738
sign, formatter, floatmode, legacy)
760-
options = _format_options.copy()
739+
options = format_options.get().copy()
761740
options.update(overrides)
762741

763742
if options['legacy'] <= 113:
@@ -980,7 +959,6 @@ def __init__(self, data, precision, floatmode, suppress_small, sign=False,
980959
self.sign = sign
981960
self.exp_format = False
982961
self.large_exponent = False
983-
984962
self.fillFormat(data)
985963

986964
def fillFormat(self, data):
@@ -1062,22 +1040,23 @@ def fillFormat(self, data):
10621040
# if there are non-finite values, may need to increase pad_left
10631041
if data.size != finite_vals.size:
10641042
neginf = self.sign != '-' or any(data[isinf(data)] < 0)
1065-
nanlen = len(_format_options['nanstr'])
1066-
inflen = len(_format_options['infstr']) + neginf
10671043
offset = self.pad_right + 1 # +1 for decimal pt
1044+
current_options = format_options.get()
10681045
self.pad_left = max(
1069-
self.pad_left, nanlen - offset, inflen - offset
1046+
self.pad_left, len(current_options['nanstr']) - offset,
1047+
len(current_options['infstr']) + neginf - offset
10701048
)
10711049

10721050
def __call__(self, x):
10731051
if not np.isfinite(x):
10741052
with errstate(invalid='ignore'):
1053+
current_options = format_options.get()
10751054
if np.isnan(x):
10761055
sign = '+' if self.sign == '+' else ''
1077-
ret = sign + _format_options['nanstr']
1056+
ret = sign + current_options['nanstr']
10781057
else: # isinf
10791058
sign = '-' if x < 0 else '+' if self.sign == '+' else ''
1080-
ret = sign + _format_options['infstr']
1059+
ret = sign + current_options['infstr']
10811060
return ' '*(
10821061
self.pad_left + self.pad_right + 1 - len(ret)
10831062
) + ret
@@ -1468,10 +1447,10 @@ def _void_scalar_to_string(x, is_repr=True):
14681447
scalartypes.c.src code, and is placed here because it uses the elementwise
14691448
formatters defined above.
14701449
"""
1471-
options = _format_options.copy()
1450+
options = format_options.get().copy()
14721451

14731452
if options["legacy"] <= 125:
1474-
return StructuredVoidFormat.from_data(array(x), **_format_options)(x)
1453+
return StructuredVoidFormat.from_data(array(x), **options)(x)
14751454

14761455
if options.get('formatter') is None:
14771456
options['formatter'] = {}
@@ -1515,7 +1494,7 @@ def dtype_is_implied(dtype):
15151494
array([1, 2, 3], dtype=int8)
15161495
"""
15171496
dtype = np.dtype(dtype)
1518-
if _format_options['legacy'] <= 113 and dtype.type == np.bool:
1497+
if format_options.get()['legacy'] <= 113 and dtype.type == np.bool:
15191498
return False
15201499

15211500
# not just void types can be structured, and names are not part of the repr
@@ -1565,12 +1544,13 @@ def _array_repr_implementation(
15651544
arr, max_line_width=None, precision=None, suppress_small=None,
15661545
array2string=array2string):
15671546
"""Internal version of array_repr() that allows overriding array2string."""
1568-
override_repr = _format_options["override_repr"]
1547+
current_options = format_options.get()
1548+
override_repr = current_options["override_repr"]
15691549
if override_repr is not None:
15701550
return override_repr(arr)
15711551

15721552
if max_line_width is None:
1573-
max_line_width = _format_options['linewidth']
1553+
max_line_width = current_options['linewidth']
15741554

15751555
if type(arr) is not ndarray:
15761556
class_name = type(arr).__name__
@@ -1582,7 +1562,7 @@ def _array_repr_implementation(
15821562
prefix = class_name + "("
15831563
suffix = ")" if skipdtype else ","
15841564

1585-
if (_format_options['legacy'] <= 113 and
1565+
if (current_options['legacy'] <= 113 and
15861566
arr.shape == () and not arr.dtype.names):
15871567
lst = repr(arr.item())
15881568
elif arr.size > 0 or arr.shape == (0,):
@@ -1603,7 +1583,7 @@ def _array_repr_implementation(
16031583
# Note: This line gives the correct result even when rfind returns -1.
16041584
last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1)
16051585
spacer = " "
1606-
if _format_options['legacy'] <= 113:
1586+
if current_options['legacy'] <= 113:
16071587
if issubclass(arr.dtype.type, flexible):
16081588
spacer = '\n' + ' '*len(class_name + "(")
16091589
elif last_line_len + len(dtype_str) + 1 > max_line_width:
@@ -1677,7 +1657,7 @@ def _array_str_implementation(
16771657
a, max_line_width=None, precision=None, suppress_small=None,
16781658
array2string=array2string):
16791659
"""Internal version of array_str() that allows overriding array2string."""
1680-
if (_format_options['legacy'] <= 113 and
1660+
if (format_options.get()['legacy'] <= 113 and
16811661
a.shape == () and not a.dtype.names):
16821662
return str(a.item())
16831663

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,7 @@ python_sources = [
13241324
'numerictypes.py',
13251325
'numerictypes.pyi',
13261326
'overrides.py',
1327+
'printoptions.py',
13271328
'records.py',
13281329
'records.pyi',
13291330
'shape_base.py',

numpy/_core/multiarray.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
'may_share_memory', 'min_scalar_type', 'ndarray', 'nditer', 'nested_iters',
4040
'normalize_axis_index', 'packbits', 'promote_types', 'putmask',
4141
'ravel_multi_index', 'result_type', 'scalar', 'set_datetimeparse_function',
42-
'set_legacy_print_mode',
4342
'set_typeDict', 'shares_memory', 'typeinfo',
4443
'unpackbits', 'unravel_index', 'vdot', 'where', 'zeros',
4544
'_get_promotion_state', '_set_promotion_state']

numpy/_core/printoptions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
Stores and defines the low-level format_options context variable.
3+
4+
This is defined in its own file outside of the arrayprint module
5+
so we can import it from C while initializing the multiarray
6+
C module during import without introducing circular dependencies.
7+
"""
8+
9+
import sys
10+
from contextvars import ContextVar
11+
12+
__all__ = ["format_options"]
13+
14+
default_format_options_dict = {
15+
"edgeitems": 3, # repr N leading and trailing items of each dimension
16+
"threshold": 1000, # total items > triggers array summarization
17+
"floatmode": "maxprec",
18+
"precision": 8, # precision of floating point representations
19+
"suppress": False, # suppress printing small floating values in exp format
20+
"linewidth": 75,
21+
"nanstr": "nan",
22+
"infstr": "inf",
23+
"sign": "-",
24+
"formatter": None,
25+
# Internally stored as an int to simplify comparisons; converted from/to
26+
# str/False on the way in/out.
27+
'legacy': sys.maxsize,
28+
'override_repr': None,
29+
}
30+
31+
format_options = ContextVar(
32+
"format_options", default=default_format_options_dict.copy())

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "npy_config.h"
3030
#include "npy_pycompat.h"
3131
#include "npy_import.h"
32+
#include "npy_static_data.h"
3233
#include "convert_datatype.h"
3334
#include "legacy_dtype_implementation.h"
3435

@@ -64,7 +65,6 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
6465
#include "ctors.h"
6566
#include "array_assign.h"
6667
#include "common.h"
67-
#include "npy_static_data.h"
6868
#include "cblasfuncs.h"
6969
#include "vdot.h"
7070
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
@@ -99,16 +99,45 @@ _umath_strings_richcompare(
9999
PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip);
100100

101101

102-
static PyObject *
103-
set_legacy_print_mode(PyObject *NPY_UNUSED(self), PyObject *args)
104-
{
105-
if (!PyArg_ParseTuple(args, "i", &npy_thread_unsafe_state.legacy_print_mode)) {
106-
return NULL;
102+
NPY_NO_EXPORT int
103+
get_legacy_print_mode(void) {
104+
/* Get the C value of the legacy printing mode.
105+
*
106+
* It is stored as a Python context variable so we access it via the C
107+
* API. For simplicity the mode is encoded as an integer where INT_MAX
108+
* means no legacy mode, and '113'/'121'/'125' means 1.13/1.21/1.25 legacy
109+
* mode; and 0 maps to INT_MAX. We can upgrade this if we have more
110+
* complex requirements in the future.
111+
*/
112+
PyObject *format_options = NULL;
113+
PyContextVar_Get(npy_static_pydata.format_options, NULL, &format_options);
114+
if (format_options == NULL) {
115+
PyErr_SetString(PyExc_SystemError,
116+
"NumPy internal error: unable to get format_options "
117+
"context variable");
118+
return -1;
107119
}
108-
if (!npy_thread_unsafe_state.legacy_print_mode) {
109-
npy_thread_unsafe_state.legacy_print_mode = INT_MAX;
120+
PyObject *legacy_print_mode = NULL;
121+
if (PyDict_GetItemRef(format_options, npy_interned_str.legacy,
122+
&legacy_print_mode) == -1) {
123+
return -1;
110124
}
111-
Py_RETURN_NONE;
125+
Py_DECREF(format_options);
126+
if (legacy_print_mode == NULL) {
127+
PyErr_SetString(PyExc_SystemError,
128+
"NumPy internal error: unable to get legacy print "
129+
"mode");
130+
return -1;
131+
}
132+
Py_ssize_t ret = PyLong_AsSsize_t(legacy_print_mode);
133+
Py_DECREF(legacy_print_mode);
134+
if (error_converting(ret)) {
135+
return -1;
136+
}
137+
if (ret > INT_MAX) {
138+
return INT_MAX;
139+
}
140+
return (int)ret;
112141
}
113142

114143

@@ -4540,8 +4569,6 @@ static struct PyMethodDef array_module_methods[] = {
45404569
METH_VARARGS | METH_KEYWORDS, NULL},
45414570
{"normalize_axis_index", (PyCFunction)normalize_axis_index,
45424571
METH_FASTCALL | METH_KEYWORDS, NULL},
4543-
{"set_legacy_print_mode", (PyCFunction)set_legacy_print_mode,
4544-
METH_VARARGS, NULL},
45454572
{"_discover_array_parameters", (PyCFunction)_discover_array_parameters,
45464573
METH_FASTCALL | METH_KEYWORDS, NULL},
45474574
{"_get_castingimpl", (PyCFunction)_get_castingimpl,
@@ -4771,8 +4798,6 @@ initialize_thread_unsafe_state(void) {
47714798
npy_thread_unsafe_state.warn_if_no_mem_policy = 0;
47724799
}
47734800

4774-
npy_thread_unsafe_state.legacy_print_mode = INT_MAX;
4775-
47764801
return 0;
47774802
}
47784803

numpy/_core/src/multiarray/multiarraymodule.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,6 @@ typedef struct npy_thread_unsafe_state_struct {
7171
*/
7272
int reload_guard_initialized;
7373

74-
/*
75-
* global variable to determine if legacy printing is enabled,
76-
* accessible from C. For simplicity the mode is encoded as an
77-
* integer where INT_MAX means no legacy mode, and '113'/'121'
78-
* means 1.13/1.21 legacy mode; and 0 maps to INT_MAX. We can
79-
* upgrade this if we have more complex requirements in the future.
80-
*/
81-
int legacy_print_mode;
82-
8374
/*
8475
* Holds the user-defined setting for whether or not to warn
8576
* if there is no memory policy set
@@ -91,5 +82,7 @@ typedef struct npy_thread_unsafe_state_struct {
9182

9283
NPY_VISIBILITY_HIDDEN extern npy_thread_unsafe_state_struct npy_thread_unsafe_state;
9384

85+
NPY_NO_EXPORT int
86+
get_legacy_print_mode(void);
9487

9588
#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */

0 commit comments

Comments
 (0)
0