8000 [WIP] MAINT: Wrap ``printoptions`` in ``ContextVar`` by mtsokol · Pull Request #26345 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

[WIP] MAINT: Wrap printoptions in ContextVar #26345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
MAINT: Wrap printoptions in ContextVar
  • Loading branch information
mtsokol committed Apr 25, 2024
commit f932b7ff5a4d6e5508e270b6c35f32cd4b0e5652
130 changes: 80 additions & 50 deletions numpy/_core/arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from . import multiarray
from .multiarray import (array, dragon4_positional, dragon4_scientific,
datetime_as_string, datetime_data, ndarray,
set_legacy_print_mode)
set_legacy_print_mode as c_set_legacy_print_mode)
from .fromnumeric import any
from .numeric import concatenate, asarray, errstate
from .numerictypes import (longlong, intc, int_, float64, complex128,
Expand All @@ -45,21 +45,29 @@
import operator
import warnings
import contextlib

_format_options = {
'edgeitems': 3, # repr N leading and trailing items of each dimension
'threshold': 1000, # total items > triggers array summarization
'floatmode': 'maxprec',
'precision': 8, # precision of floating point representations
'suppress': False, # suppress printing small floating values in exp format
'linewidth': 75,
'nanstr': 'nan',
'infstr': 'inf',
'sign': '-',
'formatter': None,
from contextvars import ContextVar


_default_format_options_dict = {
"edgeitems": 3, # repr N leading and trailing items of each dimension
"threshold": 1000, # total items > triggers array summarization
"floatmode": "maxprec",
"precision": 8, # precision of floating point representations
"suppress": False, # suppress printing small floating values in exp format
"linewidth": 75,
"nanstr": "nan",
"infstr": "inf",
"sign": "-",
"formatter": None,
# Internally stored as an int to simplify comparisons; converted from/to
# str/False on the way in/out.
'legacy': sys.maxsize}
"legacy": sys.maxsize,
}


_format_options = ContextVar(
"format_options", default=_default_format_options_dict.copy())


def _make_options_dict(precision=None, threshold=None, edgeitems=None,
linewidth=None, suppress=None, nanstr=None, infstr=None,
Expand Down Expand Up @@ -115,6 +123,43 @@ def _make_options_dict(precision=None, threshold=None, edgeitems=None,
return options


def _set_legacy_print_mode(format_options: dict) -> None:
# set the C variable for legacy mode
if format_options['legacy'] == 113:
c_set_legacy_print_mode(113)
# reset the sign option in legacy mode to avoid confusion
format_options['sign'] = '-'
elif format_options['legacy'] == 121:
c_set_legacy_print_mode(121)
elif format_options['legacy'] == 125:
c_set_legacy_print_mode(125)
elif format_options['legacy'] == sys.maxsize:
c_set_legacy_print_mode(0)


def _set_printoptions(precision=None, threshold=None, edgeitems=None,
linewidth=None, suppress=None, nanstr=None,
infstr=None, formatter=None, sign=None, floatmode=None,
legacy=None):
"""
...
"""
new_opt = _make_options_dict(precision, threshold, edgeitems, linewidth,
suppress, nanstr, infstr, sign, formatter,
floatmode, legacy)
# formatter is always reset
new_opt['formatter'] = formatter

current_opt = _format_options.get().copy()
current_opt.update(new_opt)
updated_opt = current_opt

_set_legacy_print_mode(updated_opt)

token = _format_options.set(updated_opt)
return token


@set_module('numpy')
def set_printoptions(precision=None, threshold=None, edgeitems=None,
linewidth=None, suppress=None, nanstr=None,
Expand Down Expand Up @@ -283,24 +328,8 @@ def set_printoptions(precision=None, threshold=None, edgeitems=None,
array([ 0. , 1.11, 2.22, ..., 7.78, 8.89, 10. ])

"""
opt = _make_options_dict(precision, threshold, edgeitems, linewidth,
suppress, nanstr, infstr, sign, formatter,
floatmode, legacy)
# formatter is always reset
opt['formatter'] = formatter
_format_options.update(opt)

# set the C variable for legacy mode
if _format_options['legacy'] == 113:
set_legacy_print_mode(113)
# reset the sign option in legacy mode to avoid confusion
_format_options['sign'] = '-'
elif _format_options['legacy'] == 121:
set_legacy_print_mode(121)
elif _format_options['legacy'] == 125:
set_legacy_print_mode(125)
elif _format_options['legacy'] == sys.maxsize:
set_legacy_print_mode(0)
_set_printoptions(precision, threshold, edgeitems, linewidth, suppress,
nanstr, infstr, formatter, sign, floatmode, legacy)


@set_module('numpy')
Expand Down Expand Up @@ -330,7 +359,7 @@ def get_printoptions():
set_printoptions, printoptions

"""
opts = _format_options.copy()
opts = _format_options.get().copy()
opts['legacy'] = {
113: '1.13', 121: '1.21', 125: '1.25', sys.maxsize: False,
}[opts['legacy']]
Expand All @@ -339,7 +368,7 @@ def get_printoptions():

def _get_legacy_print_mode():
"""Return the legacy print mode as an int."""
return _format_options['legacy']
return _format_options.get()['legacy']


@set_module('numpy')
Expand Down Expand Up @@ -369,12 +398,12 @@ def printoptions(*args, **kwargs):
set_printoptions, get_printoptions

"""
opts = np.get_printoptions()
try:
np.set_printoptions(*args, **kwargs)
yield np.get_printoptions()
token = _set_printoptions(*args, **kwargs)
yield get_printoptions()
finally:
np.set_printoptions(**opts)
_format_options.reset(token)
_set_legacy_print_mode(_format_options.get())


def _leading_trailing(a, edgeitems, index=()):
Expand Down Expand Up @@ -732,7 +761,7 @@ def array2string(a, max_line_width=None, precision=None,
overrides = _make_options_dict(precision, threshold, edgeitems,
max_line_width, suppress_small, None, None,
sign, formatter, floatmode, legacy)
options = _format_options.copy()
options = _format_options.get().copy()
options.update(overrides)

if options['legacy'] <= 113:
Expand Down Expand Up @@ -1037,8 +1066,8 @@ def fillFormat(self, data):
# if there are non-finite values, may need to increase pad_left
if data.size != finite_vals.size:
neginf = self.sign != '-' or any(data[isinf(data)] < 0)
nanlen = len(_format_options['nanstr'])
inflen = len(_format_options['infstr']) + neginf
nanlen = len(_format_options.get()['nanstr'])
inflen = len(_format_options.get()['infstr']) + neginf
offset = self.pad_right + 1 # +1 for decimal pt
self.pad_left = max(
self.pad_left, nanlen - offset, inflen - offset
Expand All @@ -1049,10 +1078,10 @@ def __call__(self, x):
with errstate(invalid='ignore'):
if np.isnan(x):
sign = '+' if self.sign == '+' else ''
ret = sign + _format_options['nanstr']
ret = sign + _format_options.get()['nanstr']
else: # isinf
sign = '-' if x < 0 else '+' if self.sign == '+' else ''
ret = sign + _format_options['infstr']
ret = sign + _format_options.get()['infstr']
return ' '*(
self.pad_left + self.pad_right + 1 - len(ret)
) + ret
Expand Down Expand Up @@ -1443,10 +1472,11 @@ def _void_scalar_to_string(x, is_repr=True):
scalartypes.c.src code, and is placed here because it uses the elementwise
formatters defined above.
"""
options = _format_options.copy()
options = _format_options.get().copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure it makes sense, but I was half hoping we can just stop passing options around completely and enter a new context ourselves when necessary. contextvar lookup is almost as fast as dict lookup, and global lookup also not really expensive at the function level (within a loop it might be).


if options["legacy"] <= 125:
return StructuredVoidFormat.from_data(array(x), **_format_options)(x)
return StructuredVoidFormat.from_data(
array(x), **_format_options.get())(x)

if options.get('formatter') is None:
options['formatter'] = {}
Expand Down Expand Up @@ -1490,7 +1520,7 @@ def dtype_is_implied(dtype):
array([1, 2, 3], dtype=int8)
"""
dtype = np.dtype(dtype)
if _format_options['legacy'] <= 113 and dtype.type == np.bool:
if _format_options.get()['legacy'] <= 113 and dtype.type == np.bool:
return False

# not just void types can be structured, and names are not part of the repr
Expand Down Expand Up @@ -1541,7 +1571,7 @@ def _array_repr_implementation(
array2string=array2string):
"""Internal version of array_repr() that allows overriding array2string."""
if max_line_width is None:
max_line_width = _format_options['linewidth']
max_line_width = _format_options.get()['linewidth']

if type(arr) is not ndarray:
class_name = type(arr).__name__
Expand All @@ -1553,7 +1583,7 @@ def _array_repr_implementation(
prefix = class_name + "("
suffix = ")" if skipdtype else ","

if (_format_options['legacy'] <= 113 and
if (_format_options.get()['legacy'] <= 113 and
arr.shape == () and not arr.dtype.names):
lst = repr(arr.item())
elif arr.size > 0 or arr.shape == (0,):
Expand All @@ -1574,7 +1604,7 @@ def _array_repr_implementation(
# Note: This line gives the correct result even when rfind returns -1.
last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1)
spacer = " "
if _format_options['legacy'] <= 113:
if _format_options.get()['legacy'] <= 113:
if issubclass(arr.dtype.type, flexible):
spacer = '\n' + ' '*len(class_name + "(")
elif last_line_len + len(dtype_str) + 1 > max_line_width:
Expand Down Expand Up @@ -1648,7 +1678,7 @@ def _array_str_implementation(
a, max_line_width=None, precision=None, suppress_small=None,
array2string=array2string):
"""Internal version of array_str() that allows overriding array2string."""
if (_format_options['legacy'] <= 113 and
if (_format_options.get()['legacy'] <= 113 and
a.shape == () and not a.dtype.names):
return str(a.item())

Expand Down
0