diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index 8b55c9fbd79f..3d93e801a256 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -2369,11 +2369,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) { int elsize = -1, alignment = -1; int version = 4; -#if defined(NPY_PY3K) - int endian; -#else char endian; -#endif + PyObject *endian_obj; PyObject *subarray, *fields, *names = NULL, *metadata=NULL; int incref_names = 1; int int_dtypeflags = 0; @@ -2390,68 +2387,39 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) } switch (PyTuple_GET_SIZE(PyTuple_GET_ITEM(args,0))) { case 9: -#if defined(NPY_PY3K) -#define _ARGSTR_ "(iCOOOiiiO)" -#else -#define _ARGSTR_ "(icOOOiiiO)" -#endif - if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian, + if (!PyArg_ParseTuple(args, "(iOOOOiiiO)", &version, &endian_obj, &subarray, &names, &fields, &elsize, &alignment, &int_dtypeflags, &metadata)) { + PyErr_Clear(); return NULL; -#undef _ARGSTR_ } break; case 8: -#if defined(NPY_PY3K) -#define _ARGSTR_ "(iCOOOiii)" -#else -#define _ARGSTR_ "(icOOOiii)" -#endif - if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian, + if (!PyArg_ParseTuple(args, "(iOOOOiii)", &version, &endian_obj, &subarray, &names, &fields, &elsize, &alignment, &int_dtypeflags)) { return NULL; -#undef _ARGSTR_ } break; case 7: -#if defined(NPY_PY3K) -#define _ARGSTR_ "(iCOOOii)" -#else -#define _ARGSTR_ "(icOOOii)" -#endif - if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian, + if (!PyArg_ParseTuple(args, "(iOOOOii)", &version, &endian_obj, &subarray, &names, &fields, &elsize, &alignment)) { return NULL; -#undef _ARGSTR_ } break; case 6: -#if defined(NPY_PY3K) -#define _ARGSTR_ "(iCOOii)" -#else -#define _ARGSTR_ "(icOOii)" -#endif - if (!PyArg_ParseTuple(args, _ARGSTR_, &version, - &endian, &subarray, &fields, + if (!PyArg_ParseTuple(args, "(iOOOii)", &version, + &endian_obj, &subarray, &fields, &elsize, &alignment)) { - PyErr_Clear(); -#undef _ARGSTR_ + return NULL; } break; case 5: version = 0; -#if defined(NPY_PY3K) -#define _ARGSTR_ "(COOii)" -#else -#define _ARGSTR_ "(cOOii)" -#endif - if (!PyArg_ParseTuple(args, _ARGSTR_, - &endian, &subarray, &fields, &elsize, + if (!PyArg_ParseTuple(args, "(OOOii)", + &endian_obj, &subarray, &fields, &elsize, &alignment)) { -#undef _ARGSTR_ return NULL; } break; @@ -2494,6 +2462,38 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args) } } + /* Parse endian */ + if (PyUnicode_Check(endian_obj) || PyBytes_Check(endian_obj)) { + PyObject *tmp = NULL; + char *str; + Py_ssize_t len; + + if (PyUnicode_Check(endian_obj)) { + tmp = PyUnicode_AsASCIIString(endian_obj); + if (tmp == NULL) { + return NULL; + } + endian_obj = tmp; + } + + if (PyBytes_AsStringAndSize(endian_obj, &str, &len) == -1) { + Py_XDECREF(tmp); + return NULL; + } + if (len != 1) { + PyErr_SetString(PyExc_ValueError, + "endian is not 1-char string in Numpy dtype unpickling"); + Py_XDECREF(tmp); + return NULL; + } + endian = str[0]; + Py_XDECREF(tmp); + } + else { + PyErr_SetString(PyExc_ValueError, + "endian is not a string in Numpy dtype unpickling"); + return NULL; + } if ((fields == Py_None && names != Py_None) || (names == Py_None && fields != Py_None)) { diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index ed187ce15737..a83713a7562f 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -398,6 +398,29 @@ def __getitem__(self, key): assert_raises(KeyError, np.lexsort, BuggySequence()) + def test_pickle_py2_bytes_encoding(self): + # Check that arrays and scalars pickled on Py2 are + # unpickleable on Py3 using encoding='bytes' + + test_data = [ + # (original, py2_pickle) + (np.unicode_('\u6f2c'), + asbytes("cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n" + "(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\n" + "I0\ntp6\nbS',o\\x00\\x00'\np7\ntp8\nRp9\n.")), + + (np.array([9e123], dtype=np.float64), + asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\n" + "p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n" + "p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n" + "I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")), + ] + + if sys.version_info[:2] >= (3, 4): + # encoding='bytes' was added in Py3.4 + for original, data in test_data: + result = pickle.loads(data, encoding='bytes') + assert_equal(result, original) def test_pickle_dtype(self,level=rlevel): """Ticket #251"""