8000 Merge pull request #4888 from pv/fix-bytes-encoding-unpickle · numpy/numpy@857c5e2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 857c5e2

Browse files
committed
Merge pull request #4888 from pv/fix-bytes-encoding-unpickle
ENH: core: make unpickling with encoding='bytes' work
2 parents bd0a4f3 + 16f39c8 commit 857c5e2

File tree

2 files changed

+166
-50
lines changed

2 files changed

+166
-50
lines changed

numpy/core/src/multiarray/descriptor.c

Lines changed: 131 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,11 +2369,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
23692369
{
23702370
int elsize = -1, alignment = -1;
23712371
int version = 4;
2372-
#if defined(NPY_PY3K)
2373-
int endian;
2374-
#else
23752372
char endian;
2376-
#endif
2373+
PyObject *endian_obj;
23772374
PyObject *subarray, *fields, *names = NULL, *metadata=NULL;
23782375
int incref_names = 1;
23792376
int int_dtypeflags = 0;
@@ -2390,68 +2387,39 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
23902387
}
23912388
switch (PyTuple_GET_SIZE(PyTuple_GET_ITEM(args,0))) {
23922389
case 9:
2393-
#if defined(NPY_PY3K)
2394-
#define _ARGSTR_ "(iCOOOiiiO)"
2395-
#else
2396-
#define _ARGSTR_ "(icOOOiiiO)"
2397-
#endif
2398-
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
2390+
if (!PyArg_ParseTuple(args, "(iOOOOiiiO)", &version, &endian_obj,
23992391
&subarray, &names, &fields, &elsize,
24002392
&alignment, &int_dtypeflags, &metadata)) {
2393+
PyErr_Clear();
24012394
return NULL;
2402-
#undef _ARGSTR_
24032395
}
24042396
break;
24052397
case 8:
2406-
#if defined(NPY_PY3K)
2407-
#define _ARGSTR_ "(iCOOOiii)"
2408-
#else
2409-
#define _ARGSTR_ "(icOOOiii)"
2410-
#endif
2411-
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
2398+
if (!PyArg_ParseTuple(args, "(iOOOOiii)", &version, &endian_obj,
24122399
&subarray, &names, &fields, &elsize,
24132400
&alignment, &int_dtypeflags)) {
24142401
return NULL;
2415-
#undef _ARGSTR_
24162402
}
24172403
break;
24182404
case 7:
2419-
#if defined(NPY_PY3K)
2420-
#define _ARGSTR_ "(iCOOOii)"
2421-
#else
2422-
#define _ARGSTR_ "(icOOOii)"
2423-
#endif
2424-
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
2405+
if (!PyArg_ParseTuple(args, "(iOOOOii)", &version, &endian_obj,
24252406
&subarray, &names, &fields, &elsize,
24262407
&alignment)) {
24272408
return NULL;
2428-
#undef _ARGSTR_
24292409
}
24302410
break;
24312411
case 6:
2432-
#if defined(NPY_PY3K)
2433-
#define _ARGSTR_ "(iCOOii)"
2434-
#else
2435-
#define _ARGSTR_ "(icOOii)"
2436-
#endif
2437-
if (!PyArg_ParseTuple(args, _ARGSTR_, &version,
2438-
&endian, &subarray, &fields,
2412+
if (!PyArg_ParseTuple(args, "(iOOOii)", &version,
2413+
&endian_obj, &subarray, &fields,
24392414
&elsize, &alignment)) {
2440-
PyErr_Clear();
2441-
#undef _ARGSTR_
2415+
return NULL;
24422416
}
24432417
break;
24442418
case 5:
24452419
version = 0;
2446-
#if defined(NPY_PY3K)
2447-
#define _ARGSTR_ "(COOii)"
2448-
#else
2449-
#define _ARGSTR_ "(cOOii)"
2450-
#endif
2451-
if (!PyArg_ParseTuple(args, _ARGSTR_,
2452-
&endian, &subarray, &fields, &elsize,
2420+
if (!PyArg_ParseTuple(args, "(OOOii)",
2421+
&endian_obj, &subarray, &fields, &elsize,
24532422
&alignment)) {
2454-
#undef _ARGSTR_
24552423
return NULL;
24562424
}
24572425
break;
@@ -2494,11 +2462,55 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
24942462
}
24952463
}
24962464

2465+
/* Parse endian */
2466+
if (PyUnicode_Check(endian_obj) || PyBytes_Check(endian_obj)) {
2467+
PyObject *tmp = NULL;
2468+
char *str;
2469+
Py_ssize_t len;
2470+
2471+
if (PyUnicode_Check(endian_obj)) {
2472+
tmp = PyUnicode_AsASCIIString(endian_obj);
2473+
if (tmp == NULL) {
2474+
return NULL;
2475+
}
2476+
endian_obj = tmp;
2477+
}
2478+
2479+
if (PyBytes_AsStringAndSize(endian_obj, &str, &len) == -1) {
2480+
Py_XDECREF(tmp);
2481+
return NULL;
2482+
}
2483+
if (len != 1) {
2484+
PyErr_SetString(PyExc_ValueError,
2485+
"endian is not 1-char string in Numpy dtype unpickling");
2486+
Py_XDECREF(tmp);
2487+
return NULL;
2488+
}
2489+
endian = str[0];
2490+
Py_XDECREF(tmp);
2491+
}
2492+
else {
2493+
PyErr_SetString(PyExc_ValueError,
2494+
"endian is not a string in Numpy dtype unpickling");
2495+
return NULL;
2496+
}
24972497

24982498
if ((fields == Py_None && names != Py_None) ||
24992499
(names == Py_None && fields != Py_None)) {
25002500
PyErr_Format(PyExc_ValueError,
2501-
"inconsistent fields and names");
2501+
"inconsistent fields and names in Numpy dtype unpickling");
2502+
return NULL;
2503+
}
2504+
2505+
if (names != Py_None && !PyTuple_Check(names)) {
2506+
PyErr_Format(PyExc_ValueError,
2507+
"non-tuple names in Numpy dtype unpickling");
2508+
return NULL;
2509+
}
2510+
2511+
if (fields != Py_None && !PyDict_Check(fields)) {
2512+
PyErr_Format(PyExc_ValueError,
2513+
"non-dict fields in Numpy dtype unpickling");
25022514
return NULL;
25032515
}
25042516

@@ -2563,13 +2575,82 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
25632575
}
25642576

25652577
if (fields != Py_None) {
2566-
Py_XDECREF(self->fields);
2567-
self->fields = fields;
2568-
Py_INCREF(fields);
2569-
Py_XDECREF(self->names);
2570-
self->names = names;
2571-
if (incref_names) {
2572-
Py_INCREF(names);
2578+
/*
2579+
* Ensure names are of appropriate string type
2580+
*/
2581+
Py_ssize_t i;
2582+
int names_ok = 1;
2583+
PyObject *name;
2584+
2585+
for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
2586+
name = PyTuple_GET_ITEM(names, i);
2587+
if (!PyUString_Check(name)) {
2588+
names_ok = 0;
2589+
break;
2590+
}
2591+
}
2592+
2593+
if (names_ok) {
2594+
Py_XDECREF(self->fields);
2595+
self->fields = fields;
2596+
Py_INCREF(fields);
2597+
Py_XDECREF(self->names);
2598+
self->names = names;
2599+
if (incref_names) {
2600+
Py_INCREF(names);
2601+
}
2602+
}
2603+
else {
2604+
#if defined(NPY_PY3K)
2605+
/*
2606+
* To support pickle.load(f, encoding='bytes') for loading Py2
2607+
* generated pickles on Py3, we need to be more lenient and convert
2608+
* field names from byte strings to unicode.
2609+
*/
2610+
PyObject *tmp, *new_name, *field;
2611+
2612+
tmp = PyDict_New();
2613+
if (tmp == NULL) {
2614+
return NULL;
2615+
}
2616+
Py_XDECREF(self->fields);
2617+
self->fields = tmp;
2618+
2619+
tmp = PyTuple_New(PyTuple_GET_SIZE(names));
2620+
if (tmp == NULL) {
2621+
return NULL;
2622+
}
2623+
Py_XDECREF(self->names);
2624+
self->names = tmp;
2625+
2626+
for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
2627+
name = PyTuple_GET_ITEM(names, i);
2628+
field = PyDict_GetItem(fields, name);
2629+
if (!field) {
2630+
return NULL;
2631+
}
2632+
2633+
if (PyUnicode_Check(name)) {
2634+
new_name = name;
2635+
Py_INCREF(new_name);
2636+
}
2637+
else {
2638+
new_name = PyUnicode_FromEncodedObject(name, "ASCII", "strict");
2639+
if (new_name == NULL) {
2640+
return NULL;
2641+
}
2642+
}
2643+
2644+
PyTuple_SET_ITEM(self->names, i, new_name);
2645+
if (PyDict_SetItem(self->fields, new_name, field) != 0) {
2646+
return NULL;
2647+
}
2648+
}
2649+
#else
2650+
PyErr_Format(PyExc_ValueError,
2651+
"non-string names in Numpy dtype unpickling");
2652+
return NULL;
2653+
#endif
25732654
}
25742655
}
25752656

numpy/core/tests/test_regression.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,41 @@ def __getitem__(self, key):
398398

399399
assert_raises(KeyError, np.lexsort, BuggySequence())
400400

401+
def test_pickle_py2_bytes_encoding(self):
402+
# Check that arrays and scalars pickled on Py2 are
403+
# unpickleable on Py3 using encoding='bytes'
404+
405+
test_data = [
406+
# (original, py2_pickle)
407+
(np.unicode_('\u6f2c'),
408+
asbytes("cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n"
409+
"(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\n"
410+
"I0\ntp6\nbS',o\\x00\\x00'\np7\ntp8\nRp9\n.")),
411+
412+
(np.array([9e123], dtype=np.float64),
413+
asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\n"
414+
"p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n"
415+
"p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n"
416+
"I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")),
417+
418+
(np.array([(9e123,)], dtype=[('name', float)]),
419+
asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n"
420+
"(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n"
421+
"(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n"
422+
"(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\n"
423+
"I-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\n"
424+
"bI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb.")),
425+
]
426+
427+
if sys.version_info[:2] >= (3, 4):
428+
# encoding='bytes' was added in Py3.4
429+
for original, data in test_data:
430+
result = pickle.loads(data, encoding='bytes')
431+
assert_equal(result, original)
432+
433+
if isinstance(result, np.ndarray) and result.dtype.names:
434+
for name in result.dtype.names:
435+
assert_(isinstance(name, str))
401436

402437
def test_pickle_dtype(self,level=rlevel):
403438
"""Ticket #251"""

0 commit comments

Comments
 (0)
0