Merge pull request #17295 from seberg/issue-17294 · numpy/numpy@d62b0ee · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d62b0ee

Browse files
authored
Merge pull request #17295 from seberg/issue-17294
BUG,ENH: fix pickling user-scalars by allowing non-format buffer export
2 parents 4c83c04 + d02ca96 commit d62b0ee

File tree

5 files changed

+106
-27
lines changed

5 files changed

+106
-27
lines changed

numpy/core/src/multiarray/buffer.c

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ static PyObject *_buffer_info_cache = NULL;
456456

457457
/* Fill in the info structure */
458458
static _buffer_info_t*
459-
_buffer_info_new(PyObject *obj, npy_bool f_contiguous)
459+
_buffer_info_new(PyObject *obj, int flags)
460460
{
461461
/*
462462
* Note that the buffer info is cached as PyLongObjects making them appear
@@ -514,6 +514,7 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous)
514514
* (This is unnecessary, but has no effect in the case where
515515
* NPY_RELAXED_STRIDES CHECKING is disabled.)
516516
*/
517+
int f_contiguous = (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS;
517518
if (PyArray_IS_C_CONTIGUOUS(arr) && !(
518519
f_contiguous && PyArray_IS_F_CONTIGUOUS(arr))) {
519520
Py_ssize_t sd = PyArray_ITEMSIZE(arr);
@@ -547,16 +548,20 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous)
547548
}
548549

549550
/* Fill in format */
550-
err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
551-
Py_DECREF(descr);
552-
if (err != 0) {
553-
goto fail;
551+
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
552+
err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
553+
Py_DECREF(descr);
554+
if (err != 0) {
555+
goto fail;
556+
}
557+
if (_append_char(&fmt, '\0') < 0) {
558+
goto fail;
559+
}
560+
info->format = fmt.s;
554561
}
555-
if (_append_char(&fmt, '\0') < 0) {
556-
goto fail;
562+
else {
563+
info->format = NULL;
557564
}
558-
info->format = fmt.s;
559-
560565
return info;
561566

562567
fail:
@@ -572,9 +577,10 @@ _buffer_info_cmp(_buffer_info_t *a, _buffer_info_t *b)
572577
Py_ssize_t c;
573578
int k;
574579

575-
c = strcmp(a->format, b->format);
576-
if (c != 0) return c;
577-
580+
if (a->format != NULL && b->format != NULL) {
581+
c = strcmp(a->format, b->format);
582+
if (c != 0) return c;
583+
}
578584
c = a->ndim - b->ndim;
579585
if (c != 0) return c;
580586

@@ -599,7 +605,7 @@ _buffer_info_free(_buffer_info_t *info)
599605

600606
/* Get buffer info from the global dictionary */
601607
static _buffer_info_t*
602-
_buffer_get_info(PyObject *obj, npy_bool f_contiguous)
608+
_buffer_get_info(PyObject *obj, int flags)
603609
{
604610
PyObject *key = NULL, *item_list = NULL, *item = NULL;
605611
_buffer_info_t *info = NULL, *old_info = NULL;
@@ -612,7 +618,7 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
612618
}
613619

614620
/* Compute information */
615-
info = _buffer_info_new(obj, f_contiguous);
621+
info = _buffer_info_new(obj, flags);
616622
if (info == NULL) {
617623
return NULL;
618624
}
@@ -630,11 +636,9 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
630636
if (item_list_length > 0) {
631637
item = PyList_GetItem(item_list, item_list_length - 1);
632638
old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item);
633-
if (_buffer_info_cmp(info, old_info) == 0) {
634-
_buffer_info_free(info);
635-
info = old_info;
636-
}
637-
else {
639+
if (_buffer_info_cmp(info, old_info) != 0) {
640+
old_info = NULL; /* Can't use this one, but possibly next */
641+
638642
if (item_list_length > 1 && info->ndim > 1) {
639643
/*
640644
* Some arrays are C- and F-contiguous and if they have more
@@ -648,12 +652,26 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
648652
*/
649653
item = PyList_GetItem(item_list, item_list_length - 2);
650654
old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item);
651-
if (_buffer_info_cmp(info, old_info) == 0) {
652-
_buffer_info_free(info);
653-
info = old_info;
655+
if (_buffer_info_cmp(info, old_info) != 0) {
656+
old_info = NULL;
654657
}
655658
}
656659
}
660+
661+
if (old_info != NULL) {
662+
/*
663+
* The two info->format are considered equal if one of them
664+
* has no format set (meaning the format is arbitrary and can
665+
* be modified). If the new info has a format, but we reuse
666+
* the old one, this transfers the ownership to the old one.
667+
*/
668+
if (old_info->format == NULL) {
669+
old_info->format = info->format;
670+
info->format = NULL;
671+
}
672+
_buffer_info_free(info);
673+
info = old_info;
674+
}
657675
}
658676
}
659677
else {
@@ -760,7 +778,7 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
760778
}
761779

762780
/* Fill in information */
763-
info = _buffer_get_info(obj, (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS);
781+
info = _buffer_get_info(obj, flags);
764782
if (info == NULL) {
765783
goto fail;
766784
}
@@ -825,7 +843,7 @@ void_getbuffer(PyObject *self, Py_buffer *view, int flags)
825843
}
826844

827845
/* Fill in information */
828-
info = _buffer_get_info(self, 0);
846+
info = _buffer_get_info(self, flags);
829847
if (info == NULL) {
830848
goto fail;
831849
}

numpy/core/src/multiarray/scalarapi.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
3535
{
3636
int type_num;
3737
int align;
38-
npy_intp memloc;
38+
uintptr_t memloc;
3939
if (descr == NULL) {
4040
descr = PyArray_DescrFromScalar(scalar);
4141
type_num = descr->type_num;
@@ -168,7 +168,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
168168
* Use the alignment flag to figure out where the data begins
169169
* after a PyObject_HEAD
170170
*/
171-
memloc = (npy_intp)scalar;
171+
memloc = (uintptr_t)scalar;
172172
memloc += sizeof(PyObject);
173173
/* now round-up to the nearest alignment value */
174174
align = descr->alignment;

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,6 +2383,50 @@ static PySequenceMethods voidtype_as_sequence = {
23832383
};
23842384

23852385

2386+
/*
2387+
* This function implements simple buffer export for user defined subclasses
2388+
* of `np.generic`. All other scalar types override the buffer export.
2389+
*/
2390+
static int
2391+
gentype_arrtype_getbuffer(PyObject *self, Py_buffer *view, int flags)
2392+
{
2393+
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
2394+
PyErr_Format(PyExc_TypeError,
2395+
"NumPy scalar %R can only exported as a buffer without format.",
2396+
self);
2397+
return -1;
2398+
}
2399+
PyArray_Descr *descr = PyArray_DescrFromScalar(self);
2400+
if (descr == NULL) {
2401+
return -1;
2402+
}
2403+
if (!PyDataType_ISUSERDEF(descr)) {
2404+
/* This path would also reject the (hopefully) impossible "object" */
2405+
PyErr_Format(PyExc_TypeError,
2406+
"user-defined scalar %R registered for built-in dtype %S? "
2407+
"This should be impossible.",
2408+
self, descr);
2409+
return -1;
2410+
}
2411+
view->ndim = 0;
2412+
view->len = descr->elsize;
2413+
view->itemsize = descr->elsize;
2414+
view->shape = NULL;
2415+
view->strides = NULL;
2416+
view->suboffsets = NULL;
2417+
Py_INCREF(self);
2418+
view->obj = self;
2419+
view->buf = scalar_value(self, descr);
2420+
Py_DECREF(descr);
2421+
view->format = NULL;
2422+
return 0;
2423+
}
2424+
2425+
2426+
static PyBufferProcs gentype_arrtype_as_buffer = {
2427+
.bf_getbuffer = (getbufferproc)gentype_arrtype_getbuffer,
2428+
};
2429+
23862430

23872431
/**begin repeat
23882432
* #name = bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong,
@@ -3794,6 +3838,7 @@ initialize_numeric_types(void)
37943838
PyGenericArrType_Type.tp_alloc = gentype_alloc;
37953839
PyGenericArrType_Type.tp_free = (freefunc)gentype_free;
37963840
PyGenericArrType_Type.tp_richcompare = gentype_richcompare;
3841+
PyGenericArrType_Type.tp_as_buffer = &gentype_arrtype_as_buffer;
37973842

37983843
PyBoolArrType_Type.tp_as_number = &bool_arrtype_as_number;
37993844
/*

numpy/core/src/umath/_rational_tests.c.src

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ static PyGetSetDef pyrational_getset[] = {
663663

664664
static PyTypeObject PyRational_Type = {
665665
PyVarObject_HEAD_INIT(NULL, 0)
666-
"rational", /* tp_name */
666+
"numpy.core._rational_tests.rational", /* tp_name */
667667
sizeof(PyRational), /* tp_basicsize */
668668
0, /* tp_itemsize */
669669
0, /* tp_dealloc */

numpy/core/tests/test_multiarray.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import numpy as np
2424
import numpy.core._multiarray_tests as _multiarray_tests
25+
from numpy.core._rational_tests import rational
2526
from numpy.testing import (
2627
assert_, assert_raises, assert_warns, assert_equal, assert_almost_equal,
2728
assert_array_equal, assert_raises_regex, assert_array_almost_equal,
@@ -7143,6 +7144,21 @@ def test_export_flags(self):
71437144
_multiarray_tests.get_buffer_info,
71447145
np.arange(5)[::2], ('SIMPLE',))
71457146

7147+
@pytest.mark.parametrize(["obj", "error"], [
        pytest.param(np.array([1, 2], dtype=rational), ValueError, id="array"),
        pytest.param(rational(1, 2), TypeError, id="scalar")])
def test_export_and_pickle_user_dtype(self, obj, error):
    """Check buffer export and pickling of a user-defined dtype.

    Requesting a buffer with FORMAT must raise `error`; requesting it
    without FORMAT must succeed, and the object must round-trip through
    pickle unchanged.
    """
    with_format = ("STRIDED", "FORMAT")
    without_format = ("STRIDED",)

    # Export with a format string is expected to fail for user dtypes.
    with pytest.raises(error):
        _multiarray_tests.get_buffer_info(obj, with_format)

    # Export without FORMAT should work.
    _multiarray_tests.get_buffer_info(obj, without_format)

    # This is currently also necessary to implement pickling:
    roundtripped = pickle.loads(pickle.dumps(obj))
    assert_array_equal(roundtripped, obj)
7161+
71467162
def test_padding(self):
71477163
for j in range(8):
71487164
x = np.array([(1,), (2,)], dtype={'f0': (int, j)})

0 commit comments

Comments
 (0)
0