Skip to content

Commit

Permalink
ENH: core: make unpickling with encoding='bytes' work
Browse files Browse the repository at this point in the history
Make dtype.__setstate__ accept endian either as a byte string or unicode.

Also fix a missing error return introduced in c355157, apparently
mistake.
  • Loading branch information
pv committed Jul 22, 2014
1 parent fa0ec11 commit 4008bb4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 42 deletions.
84 changes: 42 additions & 42 deletions numpy/core/src/multiarray/descriptor.c
Original file line number Diff line number Diff line change
Expand Up @@ -2369,11 +2369,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
{
int elsize = -1, alignment = -1;
int version = 4;
#if defined(NPY_PY3K)
int endian;
#else
char endian;
#endif
PyObject *endian_obj;
PyObject *subarray, *fields, *names = NULL, *metadata=NULL;
int incref_names = 1;
int int_dtypeflags = 0;
Expand All @@ -2390,68 +2387,39 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
switch (PyTuple_GET_SIZE(PyTuple_GET_ITEM(args,0))) {
case 9:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOOiiiO)"
#else
#define _ARGSTR_ "(icOOOiiiO)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
if (!PyArg_ParseTuple(args, "(iOOOOiiiO)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment, &int_dtypeflags, &metadata)) {
PyErr_Clear();
return NULL;
#undef _ARGSTR_
}
break;
case 8:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOOiii)"
#else
#define _ARGSTR_ "(icOOOiii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
if (!PyArg_ParseTuple(args, "(iOOOOiii)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment, &int_dtypeflags)) {
return NULL;
#undef _ARGSTR_
}
break;
case 7:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOOii)"
#else
#define _ARGSTR_ "(icOOOii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
if (!PyArg_ParseTuple(args, "(iOOOOii)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment)) {
return NULL;
#undef _ARGSTR_
}
break;
case 6:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOii)"
#else
#define _ARGSTR_ "(icOOii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version,
&endian, &subarray, &fields,
if (!PyArg_ParseTuple(args, "(iOOOii)", &version,
&endian_obj, &subarray, &fields,
&elsize, &alignment)) {
PyErr_Clear();
#undef _ARGSTR_
return NULL;
}
break;
case 5:
version = 0;
#if defined(NPY_PY3K)
#define _ARGSTR_ "(COOii)"
#else
#define _ARGSTR_ "(cOOii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_,
&endian, &subarray, &fields, &elsize,
if (!PyArg_ParseTuple(args, "(OOOii)",
&endian_obj, &subarray, &fields, &elsize,
&alignment)) {
#undef _ARGSTR_
return NULL;
}
break;
Expand Down Expand Up @@ -2494,6 +2462,38 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
}

/* Parse endian */
if (PyUnicode_Check(endian_obj) || PyBytes_Check(endian_obj)) {
PyObject *tmp = NULL;
char *str;
Py_ssize_t len;

if (PyUnicode_Check(endian_obj)) {
tmp = PyUnicode_AsASCIIString(endian_obj);
if (tmp == NULL) {
return NULL;
}
endian_obj = tmp;
}

if (PyBytes_AsStringAndSize(endian_obj, &str, &len) == -1) {
Py_XDECREF(tmp);
return NULL;
}
if (len != 1) {
PyErr_SetString(PyExc_ValueError,
"endian is not 1-char string in Numpy dtype unpickling");
Py_XDECREF(tmp);
return NULL;
}
endian = str[0];
Py_XDECREF(tmp);
}
else {
PyErr_SetString(PyExc_ValueError,
"endian is not a string in Numpy dtype unpickling");
return NULL;
}

if ((fields == Py_None && names != Py_None) ||
(names == Py_None && fields != Py_None)) {
Expand Down
23 changes: 23 additions & 0 deletions numpy/core/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,29 @@ def __getitem__(self, key):

assert_raises(KeyError, np.lexsort, BuggySequence())

def test_pickle_py2_bytes_encoding(self):
# Check that arrays and scalars pickled on Py2 are
# unpickleable on Py3 using encoding='bytes'

test_data = [
# (original, py2_pickle)
(np.unicode_('\u6f2c'),
asbytes("cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n"
"(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\n"
"I0\ntp6\nbS',o\\x00\\x00'\np7\ntp8\nRp9\n.")),

(np.array([9e123], dtype=np.float64),
asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\n"
"p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n"
"p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n"
"I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")),
]

if sys.version_info[:2] >= (3, 4):
# encoding='bytes' was added in Py3.4
for original, data in test_data:
result = pickle.loads(data, encoding='bytes')
assert_equal(result, original)

def test_pickle_dtype(self,level=rlevel):
"""Ticket #251"""
Expand Down

0 comments on commit 4008bb4

Please sign in to comment.