8000 MAINT: respond to review comments · numpy/numpy@02506cc · GitHub
[go: up one dir, main page]

Skip to content

Commit 02506cc

Browse files
committed
MAINT: respond to review comments
1 parent 23fc1e4 commit 02506cc

File tree

3 files changed

+136
-26
lines changed

3 files changed

+136
-26
lines changed

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ _umath_strings_richcompare(
102102

103103
NPY_NO_EXPORT int
104104
get_legacy_print_mode(void) {
105-
/*
106-
* Get the C value of the legacy printing mode.
107-
* accessible from C. For simplicity the mode is encoded as an
108-
* integer where INT_MAX means no legacy mode, and '113'/'121'/'125'
109-
* means 1.13/1.21/1.25 legacy mode; and 0 maps to INT_MAX. We can
110-
* upgrade this if we have more complex requirements in the future.
105+
/* Get the C value of the legacy printing mode.
106+
*
107+
* It is stored as a Python context variable so we access it via the C
108+
* API. For simplicity the mode is encoded as an integer where INT_MAX
109+
* means no legacy mode, and '113'/'121'/'125' means 1.13/1.21/1.25 legacy
110+
* mode; and 0 maps to INT_MAX. We can upgrade this if we have more
111+
* complex requirements in the future.
111112
*/
112113
PyObject *format_options = NULL;
113114
PyContextVar_Get(npy_static_pydata.format_options, NULL, &format_options);
@@ -127,17 +128,16 @@ get_legacy_print_mode(void) {
127128
PyErr_SetString(PyExc_SystemError,
128129
"NumPy internal error: unable to get legacy print "
129130
"mode");
131+
return -1;
130132
}
131-
long long ret = PyLong_AsLongLong(legacy_print_mode);
133+
Py_ssize_t ret = PyLong_AsSsize_t(legacy_print_mode);
132134
Py_DECREF(legacy_print_mode);
133135
if ((ret == -1) && PyErr_Occurred()) {
134136
return -1;
135137
}
136138
if (ret > INT_MAX) {
137139
return INT_MAX;
138140
}
139-
// in principle this cast to int could overflow, in practice ret is never
140-
// bigger than 125 and we explicitly check for values greater than INT_MAX above.
141141
return (int)ret;
142142
}
143143

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 80 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,11 @@ genint_type_repr(PyObject *self)
336336
if (value_string == NULL) {
337337
return NULL;
338338
}
339-
if (get_legacy_print_mode() <= 125) {
339+
int legacy_print_mode = get_legacy_print_mode();
340+
if (legacy_print_mode == -1) {
341+
return NULL;
342+
}
343+
if (legacy_print_mode <= 125) {
340344
return value_string;
341345
}
342346

@@ -373,7 +377,11 @@ genbool_type_str(PyObject *self)
373377
static PyObject *
374378
genbool_type_repr(PyObject *self)
375379
{
376-
if (get_legacy_print_mode() <= 125) {
380+
int legacy_print_mode = get_legacy_print_mode();
381+
if (legacy_print_mode == -1) {
382+
return NULL;
383+
}
384+
if (legacy_print_mode <= 125) {
377385
return genbool_type_str(self);
378386
}
379387
return PyUnicode_FromString(
@@ -499,7 +507,11 @@ stringtype_@form@(PyObject *self)
499507
if (ret == NULL) {
500508
return NULL;
501509
}
502-
if (get_legacy_print_mode() > 125) {
510+
int legacy_print_mode = get_legacy_print_mode();
511+
if (legacy_print_mode == -1) {
512+
return NULL;
513+
}
514+
if (legacy_print_mode > 125) {
503515
Py_SETREF(ret, PyUnicode_FromFormat("np.bytes_(%S)", ret));
504516
}
505517
#endif /* IS_repr */
@@ -546,7 +558,11 @@ unicodetype_@form@(PyObject *self)
546558
if (ret == NULL) {
547559
return NULL;
548560
}
549-
if (get_legacy_print_mode() > 125) {
561+
int legacy_print_mode = get_legacy_print_mode();
562+
if (legacy_print_mode == -1) {
563+
return NULL;
564+
}
565+
if (legacy_print_mode > 125) {
550566
Py_SETREF(ret, PyUnicode_FromFormat("np.str_(%S)", ret));
551567
}
552568
#endif /* IS_repr */
@@ -626,7 +642,11 @@ voidtype_repr(PyObject *self)
626642
/* Python helper checks for the legacy mode printing */
627643
return _void_scalar_to_string(self, 1);
628644
}
629-
if (get_legacy_print_mode() > 125) {
645+
int legacy_print_mode = get_legacy_print_mode();
646+
if (legacy_print_mode == -1) {
647+
return NULL;
648+
}
649+
if (legacy_print_mode > 125) {
630650
return _void_to_hex(s->obval, s->descr->elsize, "np.void(b'", "\\x", "')");
631651
}
632652
else {
@@ -678,7 +698,11 @@ datetimetype_repr(PyObject *self)
678698
*/
679699
if ((scal->obmeta.num == 1 && scal->obmeta.base != NPY_FR_h) ||
680700
scal->obmeta.base == NPY_FR_GENERIC) {
681-
if (get_legacy_print_mode() > 125) {
701+
int legacy_print_mode = get_legacy_print_mode();
702+
if (legacy_print_mode == -1) {
703+
return NULL;
704+
}
705+
if (legacy_print_mode > 125) {
682706
ret = PyUnicode_FromFormat("np.datetime64('%s')", iso);
683707
}
684708
else {
@@ -690,7 +714,11 @@ datetimetype_repr(PyObject *self)
690714
if (meta == NULL) {
691715
return NULL;
692716
}
693-
if (get_legacy_print_mode() > 125) {
717+
int legacy_print_mode = get_legacy_print_mode();
718+
if (legacy_print_mode == -1) {
719+
return NULL;
720+
}
721+
if (legacy_print_mode > 125) {
694722
ret = PyUnicode_FromFormat("np.datetime64('%s','%S')", iso, meta);
695723
}
696724
else {
@@ -734,7 +762,11 @@ timedeltatype_repr(PyObject *self)
734762

735763
/* The metadata unit */
736764
if (scal->obmeta.base == NPY_FR_GENERIC) {
737-
if (get_legacy_print_mode() > 125) {
765+
int legacy_print_mode = get_legacy_print_mode();
766+
if (legacy_print_mode == -1) {
767+
return NULL;
768+
}
769+
if (legacy_print_mode > 125) {
738770
ret = PyUnicode_FromFormat("np.timedelta64(%S)", val);
739771
}
740772
else {
@@ -747,7 +779,11 @@ timedeltatype_repr(PyObject *self)
747779
Py_DECREF(val);
748780
return NULL;
749781
}
750-
if (get_legacy_print_mode() > 125) {
782+
int legacy_print_mode = get_legacy_print_mode();
783+
if (legacy_print_mode == -1) {
784+
return NULL;
785+
}
786+
if (legacy_print_mode > 125) {
751787
ret = PyUnicode_FromFormat("np.timedelta64(%S,'%S')", val, meta);
752788
}
753789
else {
@@ -1049,7 +1085,11 @@ static PyObject *
10491085
npy_bool sign)
10501086
{
10511087

1052-
if (get_legacy_print_mode() <= 113) {
1088+
int legacy_print_mode = get_legacy_print_mode();
1089+
if (legacy_print_mode == -1) {
1090+
return NULL;
1091+
}
1092+
if (legacy_print_mode <= 113) {
10531093
return legacy_@name@_format@kind@(val);
10541094
}
10551095

@@ -1080,7 +1120,11 @@ static PyObject *
10801120
if (string == NULL) {
10811121
return NULL;
10821122
}
1083-
if (get_legacy_print_mode() > 125) {
1123+
int legacy_print_mode = get_legacy_print_mode();
1124+
if (legacy_print_mode == -1) {
1125+
return NULL;
1126+
}
1127+
if (legacy_print_mode > 125) {
10841128
Py_SETREF(string, PyUnicode_FromFormat("@repr_format@", string));
10851129
}
10861130
#endif /* IS_repr */
@@ -1095,7 +1139,11 @@ c@name@type_@kind@(PyObject *self)
10951139
npy_c@name@ val = PyArrayScalar_VAL(self, C@Name@);
10961140
TrimMode trim = TrimMode_DptZeros;
10971141

1098-
if (get_legacy_print_mode() <= 113) {
1142+
int legacy_print_mode = get_legacy_print_mode();
1143+
if (legacy_print_mode == -1) {
1144+
return NULL;
1145+
}
1146+
if (legacy_print_mode <= 113) {
10991147
return legacy_c@name@_format@kind@(val);
11001148
}
11011149

@@ -1108,7 +1156,11 @@ c@name@type_@kind@(PyObject *self)
11081156
#ifdef IS_str
11091157
ret = PyUnicode_FromFormat("%Sj", istr);
11101158
#else /* IS_repr */
1111-
if (get_legacy_print_mode() <= 125) {
1159+
int legacy_print_mode = get_legacy_print_mode();
1160+
if (legacy_print_mode == -1) {
1161+
return NULL;
1162+
}
1163+
if (legacy_print_mode <= 125) {
11121164
ret = PyUnicode_FromFormat("%Sj", istr);
11131165
}
11141166
else {
@@ -1156,7 +1208,11 @@ c@name@type_@kind@(PyObject *self)
11561208
#ifdef IS_str
11571209
string = PyUnicode_FromFormat("(%S%Sj)", rstr, istr);
11581210
#else /* IS_repr */
1159-
if (get_legacy_print_mode() > 125) {
1211+
legacy_print_mode = get_legacy_print_mode();
1212+
if (legacy_print_mode == -1) {
1213+
return NULL;
1214+
}
1215+
if (legacy_print_mode > 125) {
11601216
string = PyUnicode_FromFormat("@crepr_format@", rstr, istr);
11611217
}
11621218
else {
@@ -1181,7 +1237,11 @@ halftype_@kind@(PyObject *self)
11811237
float floatval = npy_half_to_float(val);
11821238
float absval;
11831239

1184-
if (get_legacy_print_mode() <= 113) {
1240+
int legacy_print_mode = get_legacy_print_mode();
1241+
if (legacy_print_mode == -1) {
1242+
return NULL;
1243+
}
1244+
if (legacy_print_mode <= 113) {
11851245
return legacy_float_format@kind@(floatval);
11861246
}
11871247

@@ -1197,7 +1257,11 @@ halftype_@kind@(PyObject *self)
11971257
#ifdef IS_str
11981258
return string;
11991259
#else
1200-
if (string == NULL || get_legacy_print_mode() <= 125) {
1260+
legacy_print_mode = get_legacy_print_mode();
1261+
if (legacy_print_mode == -1) {
1262+
return NULL;
1263+
}
1264+
if (string == NULL || legacy_print_mode <= 125) {
12011265
return string;
12021266
}
12031267
PyObject *res = PyUnicode_FromFormat("np.float16(%S)", string);

numpy/_core/tests/test_arrayprint.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from numpy.testing import (
99
assert_, assert_equal, assert_raises, assert_warns, HAS_REFCOUNT,
10-
assert_raises_regex,
10+
assert_raises_regex, IS_WASM
1111
)
1212
from numpy._core.arrayprint import _typelessdata
1313
import textwrap
@@ -1200,3 +1200,49 @@ def test_scalar_void_float_str():
12001200
# we do not do that.
12011201
scalar = np.void((1.0, 2.0), dtype=[('f0', '<f8'), ('f1', '>f4')])
12021202
assert str(scalar) == "(1.0, 2.0)"
1203+
1204+
@pytest.mark.skipif(IS_WASM, reason="wasm doesn't support asyncio")
1205+
def test_printoptions_asyncio_safe():
1206+
asyncio = pytest.importorskip("asyncio")
1207+
1208+
b = asyncio.Barrier(2)
1209+
1210+
async def legacy_113():
1211+
np.set_printoptions(legacy='1.13', precision=12)
1212+
await b.wait()
1213+
po = np.get_printoptions()
1214+
assert po['legacy'] == '1.13'
1215+
assert po['precision'] == 12
1216+
orig_linewidth = po['linewidth']
1217+
with np.printoptions(linewidth=34, legacy='1.21'):
1218+
po = np.get_printoptions()
1219+
assert po['legacy'] == '1.21'
1220+
assert po['precision'] == 12
1221+
assert po['linewidth'] == 34
1222+
po = np.get_printoptions()
1223+
assert po['linewidth'] == orig_linewidth
1224+
assert po['legacy'] == '1.13'
1225+
assert po['precision'] == 12
1226+
1227+
async def legacy_125():
1228+
np.set_printoptions(legacy='1.25', precision=7)
1229+
await b.wait()
1230+
po = np.get_printoptions()
1231+
assert po['legacy'] == '1.25'
1232+
assert po['precision'] == 7
1233+
orig_linewidth = po['linewidth']
1234+
with np.printoptions(linewidth=6, legacy='1.13'):
1235+
po = np.get_printoptions()
1236+
assert po['legacy'] == '1.13'
1237+
assert po['precision'] == 7
1238+
assert po['linewidth'] == 6
1239+
po = np.get_printoptions()
1240+
assert po['linewidth'] == orig_linewidth
1241+
assert po['legacy'] == '1.25'
1242+
assert po['precision'] == 7
1243+
1244+
async def main():
1245+
await asyncio.gather(legacy_125(), legacy_125())
1246+
1247+
loop = asyncio.new_event_loop()
1248+
asyncio.run(main())

0 commit comments

Comments
 (0)
0