Skip to content

Commit

Permalink
BUG: core: force dtype subarray->shape to be always a tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
pv committed Oct 31, 2010
1 parent 5012504 commit b124297
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
68 changes: 62 additions & 6 deletions numpy/core/src/multiarray/descriptor.c
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,24 @@ _convert_from_tuple(PyObject *obj)
newdescr->elsize *= PyArray_MultiplyList(shape.ptr, shape.len);
PyDimMem_FREE(shape.ptr);
newdescr->subarray = _pya_malloc(sizeof(PyArray_ArrayDescr));
newdescr->subarray->base = type;
newdescr->flags = type->flags;
Py_INCREF(val);
newdescr->subarray->shape = val;
newdescr->subarray->base = type;
type = NULL;
Py_XDECREF(newdescr->fields);
Py_XDECREF(newdescr->names);
newdescr->fields = NULL;
newdescr->names = NULL;
/* Force subarray->shape to always be a tuple */
if (PyTuple_Check(val)) {
Py_INCREF(val);
newdescr->subarray->shape = val;
} else {
newdescr->subarray->shape = Py_BuildValue("(O)", val);
if (newdescr->subarray->shape == NULL) {
Py_DECREF(newdescr);
goto fail;
}
}
type = newdescr;
}
return type;
Expand Down Expand Up @@ -1499,7 +1509,7 @@ arraydescr_dealloc(PyArray_Descr *self)
Py_XDECREF(self->names);
Py_XDECREF(self->fields);
if (self->subarray) {
Py_DECREF(self->subarray->shape);
Py_XDECREF(self->subarray->shape);
Py_DECREF(self->subarray->base);
_pya_free(self->subarray);
}
Expand Down Expand Up @@ -1672,6 +1682,10 @@ arraydescr_shape_get(PyArray_Descr *self)
if (self->subarray == NULL) {
return PyTuple_New(0);
}
/*TODO
* self->subarray->shape should always be a tuple,
* so this check should be unnecessary
*/
if (PyTuple_Check(self->subarray->shape)) {
Py_INCREF(self->subarray->shape);
return (PyObject *)(self->subarray->shape);
Expand Down Expand Up @@ -2351,11 +2365,49 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
self->subarray = NULL;

if (subarray != Py_None) {
PyObject *subarray_shape;

/*
* Ensure that subarray[0] is an ArrayDescr and
* that subarray_shape obtained from subarray[1] is a tuple of integers.
*/
if (!(PyTuple_Check(subarray) &&
PyTuple_Size(subarray) == 2 &&
PyArray_DescrCheck(PyTuple_GET_ITEM(subarray, 0)))) {
PyErr_Format(PyExc_ValueError,
"incorrect subarray in __setstate__");
return NULL;
}
subarray_shape = PyTuple_GET_ITEM(subarray, 1);
if (PyNumber_Check(subarray_shape)) {
PyObject *tmp;
#if defined(NPY_PY3K)
tmp = PyNumber_Long(subarray_shape);
#else
tmp = PyNumber_Int(subarray_shape);
#endif
if (tmp == NULL) {
return NULL;
}
subarray_shape = Py_BuildValue("(O)", tmp);
Py_DECREF(tmp);
if (subarray_shape == NULL) {
return NULL;
}
}
else if (_is_tuple_of_integers(subarray_shape)) {
Py_INCREF(subarray_shape);
}
else {
PyErr_Format(PyExc_ValueError,
"incorrect subarray shape in __setstate__");
return NULL;
}

self->subarray = _pya_malloc(sizeof(PyArray_ArrayDescr));
self->subarray->base = (PyArray_Descr *)PyTuple_GET_ITEM(subarray, 0);
Py_INCREF(self->subarray->base);
self->subarray->shape = PyTuple_GET_ITEM(subarray, 1);
Py_INCREF(self->subarray->shape);
self->subarray->shape = subarray_shape;
}

if (fields != Py_None) {
Expand Down Expand Up @@ -2648,6 +2700,10 @@ arraydescr_str(PyArray_Descr *self)
}
PyUString_ConcatAndDel(&t, p);
PyUString_ConcatAndDel(&t, PyUString_FromString(","));
/*TODO
* self->subarray->shape should always be a tuple,
* so this check should be unnecessary
*/
if (!PyTuple_Check(self->subarray->shape)) {
sh = Py_BuildValue("(O)", self->subarray->shape);
}
Expand Down
3 changes: 3 additions & 0 deletions numpy/core/tests/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def test_single_subarray(self):
self.assertTrue(hash(a) == hash(b),
"two equivalent types do not hash to the same value !")

assert_equal(type(a.subdtype[1]), tuple)
assert_equal(type(b.subdtype[1]), tuple)

def test_equivalent_record(self):
"""Test whether equivalent subarray dtypes hash the same."""
a = np.dtype((np.int, (2, 3)))
Expand Down
6 changes: 6 additions & 0 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,12 @@ def test_version1_object(self):
p = self._loads(asbytes(s))
assert_equal(a, p)

def test_subarray_int_shape(self):
s = "cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n(S'V6'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'a'\np12\ng3\ntp13\n(dp14\ng12\n(g7\n(S'V4'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'|'\np18\n(g7\n(S'i1'\np19\nI0\nI1\ntp20\nRp21\n(I3\nS'|'\np22\nNNNI-1\nI-1\nI0\ntp23\nb(I2\nI2\ntp24\ntp25\nNNI4\nI1\nI0\ntp26\nbI0\ntp27\nsg3\n(g7\n(S'V2'\np28\nI0\nI1\ntp29\nRp30\n(I3\nS'|'\np31\n(g21\nI2\ntp32\nNNI2\nI1\nI0\ntp33\nbI4\ntp34\nsI6\nI1\nI0\ntp35\nbI00\nS'\\x01\\x01\\x01\\x01\\x01\\x02'\np36\ntp37\nb."
a = np.array([(1,(1,2))], dtype=[('a', 'i1', (2,2)), ('b', 'i1', 2)])
p = self._loads(asbytes(s))
assert_equal(a, p)


class TestFancyIndexing(TestCase):
def test_list(self):
Expand Down

0 comments on commit b124297

Please sign in to comment.