Skip to content

Commit

Permalink
Python bindings: fix MDArray.ReadAsArray() with CInt16/CInt32 data type
Browse files Browse the repository at this point in the history
  • Loading branch information
rouault committed Apr 27, 2021
1 parent 7dd9d71 commit a514e84
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 18 deletions.
32 changes: 32 additions & 0 deletions autotest/gcore/numpy_rw_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,35 @@ def test_numpy_rw_multidim_compound_datatype():
assert myarray.WriteArray(ar) == gdal.CE_None
res = myarray.ReadAsArray()
assert np.array_equal(res, ar)

###############################################################################


@pytest.mark.parametrize("datatype", [gdal.GDT_Byte,
gdal.GDT_Int16,
gdal.GDT_UInt16,
gdal.GDT_Int32,
gdal.GDT_UInt32,
gdal.GDT_Float32,
gdal.GDT_Float64,
gdal.GDT_CInt16,
gdal.GDT_CInt32,
gdal.GDT_CFloat32,
gdal.GDT_CFloat64, ], ids=gdal.GetDataTypeName)
def test_numpy_rw_multidim_datatype(datatype):

if gdaltest.numpy_drv is None:
pytest.skip()
import numpy as np

drv = gdal.GetDriverByName('MEM')
ds = drv.CreateMultiDimensional('myds')
rg = ds.GetRootGroup()
dim = rg.CreateDimension("dim0", None, None, 2)
myarray = rg.CreateMDArray("myarray", [ dim ], gdal.ExtendedDataType.Create(datatype))
assert myarray
numpy_ar = np.reshape(np.arange(0, 2, dtype=np.uint16), (2,))
assert myarray.WriteArray(numpy_ar) == gdal.CE_None
got = myarray.ReadAsArray()
assert np.array_equal(got, numpy_ar)
assert np.array_equal(myarray.ReadAsArray(buf_obj = np.zeros(got.shape, got.dtype)), numpy_ar)
35 changes: 26 additions & 9 deletions gdal/swig/include/gdal_array.i
Original file line number Diff line number Diff line change
Expand Up @@ -1688,32 +1688,42 @@ def BandWriteArray(band, array, xoff=0, yoff=0,
_RaiseException()
return ret

def ExtendedDataTypeToNumPyDataType(dt):
def _ExtendedDataTypeToNumPyDataType(dt):
klass = dt.GetClass()

if klass == gdal.GEDTC_STRING:
return numpy.bytes_
return numpy.bytes_, dt

if klass == gdal.GEDTC_NUMERIC:
buf_type = dt.GetNumericDataType()
typecode = GDALTypeCodeToNumericTypeCode(buf_type)
if typecode is None:
typecode = numpy.float32
return typecode
dt = gdal.ExtendedDataType.Create(gdal.GDT_Float32)
else:
dt = gdal.ExtendedDataType.Create(NumericTypeCodeToGDALTypeCode(typecode))
return typecode, dt

assert klass == gdal.GEDTC_COMPOUND
names = []
formats = []
offsets = []
for comp in dt.GetComponents():
names.append(comp.GetName())
formats.append(ExtendedDataTypeToNumPyDataType(comp.GetType()))
typecode, subdt = _ExtendedDataTypeToNumPyDataType(comp.GetType())
if subdt != comp.GetType():
raise Exception("Incompatible datatype")
formats.append(typecode)
offsets.append(comp.GetOffset())

return numpy.dtype({'names': names,
'formats': formats,
'offsets': offsets,
'itemsize': dt.GetSize()})
'itemsize': dt.GetSize()}), dt

def ExtendedDataTypeToNumPyDataType(dt):
typecode, _ = _ExtendedDataTypeToNumPyDataType(dt)
return typecode

def MDArrayReadAsArray(mdarray,
array_start_idx = None,
Expand All @@ -1727,12 +1737,18 @@ def MDArrayReadAsArray(mdarray,
count = [ dim.GetSize() for dim in mdarray.GetDimensions() ]
if not array_step:
array_step = [1] * mdarray.GetDimensionCount()
if not buffer_datatype:
buffer_datatype = mdarray.GetDataType()

if buf_obj is None:
typecode = ExtendedDataTypeToNumPyDataType(buffer_datatype)
if not buffer_datatype:
buffer_datatype = mdarray.GetDataType()
typecode, buffer_datatype = _ExtendedDataTypeToNumPyDataType(buffer_datatype)
buf_obj = numpy.empty(count, dtype=typecode)
else:
datatype = NumericTypeCodeToGDALTypeCode(buf_obj.dtype.type)
if not datatype:
raise ValueError("array does not have corresponding GDAL data type")

buffer_datatype = gdal.ExtendedDataType.Create(datatype)

ret = MDArrayIONumPy(False, mdarray, buf_obj, array_start_idx, array_step, buffer_datatype)
if ret != 0:
Expand All @@ -1748,7 +1764,8 @@ def MDArrayWriteArray(mdarray, array,
array_step = [1] * mdarray.GetDimensionCount()

buffer_datatype = mdarray.GetDataType()
if array.dtype != ExtendedDataTypeToNumPyDataType(buffer_datatype):
typecode = ExtendedDataTypeToNumPyDataType(buffer_datatype)
if array.dtype != typecode:
datatype = NumericTypeCodeToGDALTypeCode(array.dtype.type)

# if we receive some odd type, like int64, try casting to a very
Expand Down
35 changes: 26 additions & 9 deletions gdal/swig/python/osgeo/gdal_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,32 +533,42 @@ def BandWriteArray(band, array, xoff=0, yoff=0,
_RaiseException()
return ret

def ExtendedDataTypeToNumPyDataType(dt):
def _ExtendedDataTypeToNumPyDataType(dt):
klass = dt.GetClass()

if klass == gdal.GEDTC_STRING:
return numpy.bytes_
return numpy.bytes_, dt

if klass == gdal.GEDTC_NUMERIC:
buf_type = dt.GetNumericDataType()
typecode = GDALTypeCodeToNumericTypeCode(buf_type)
if typecode is None:
typecode = numpy.float32
return typecode
dt = gdal.ExtendedDataType.Create(gdal.GDT_Float32)
else:
dt = gdal.ExtendedDataType.Create(NumericTypeCodeToGDALTypeCode(typecode))
return typecode, dt

assert klass == gdal.GEDTC_COMPOUND
names = []
formats = []
offsets = []
for comp in dt.GetComponents():
names.append(comp.GetName())
formats.append(ExtendedDataTypeToNumPyDataType(comp.GetType()))
typecode, subdt = _ExtendedDataTypeToNumPyDataType(comp.GetType())
if subdt != comp.GetType():
raise Exception("Incompatible datatype")
formats.append(typecode)
offsets.append(comp.GetOffset())

return numpy.dtype({'names': names,
'formats': formats,
'offsets': offsets,
'itemsize': dt.GetSize()})
'itemsize': dt.GetSize()}), dt

def ExtendedDataTypeToNumPyDataType(dt):
typecode, _ = _ExtendedDataTypeToNumPyDataType(dt)
return typecode

def MDArrayReadAsArray(mdarray,
array_start_idx = None,
Expand All @@ -572,12 +582,18 @@ def MDArrayReadAsArray(mdarray,
count = [ dim.GetSize() for dim in mdarray.GetDimensions() ]
if not array_step:
array_step = [1] * mdarray.GetDimensionCount()
if not buffer_datatype:
buffer_datatype = mdarray.GetDataType()

if buf_obj is None:
typecode = ExtendedDataTypeToNumPyDataType(buffer_datatype)
if not buffer_datatype:
buffer_datatype = mdarray.GetDataType()
typecode, buffer_datatype = _ExtendedDataTypeToNumPyDataType(buffer_datatype)
buf_obj = numpy.empty(count, dtype=typecode)
else:
datatype = NumericTypeCodeToGDALTypeCode(buf_obj.dtype.type)
if not datatype:
raise ValueError("array does not have corresponding GDAL data type")

buffer_datatype = gdal.ExtendedDataType.Create(datatype)

ret = MDArrayIONumPy(False, mdarray, buf_obj, array_start_idx, array_step, buffer_datatype)
if ret != 0:
Expand All @@ -593,7 +609,8 @@ def MDArrayWriteArray(mdarray, array,
array_step = [1] * mdarray.GetDimensionCount()

buffer_datatype = mdarray.GetDataType()
if array.dtype != ExtendedDataTypeToNumPyDataType(buffer_datatype):
typecode = ExtendedDataTypeToNumPyDataType(buffer_datatype)
if array.dtype != typecode:
datatype = NumericTypeCodeToGDALTypeCode(array.dtype.type)

# if we receive some odd type, like int64, try casting to a very
Expand Down

0 comments on commit a514e84

Please sign in to comment.