diff --git a/python/caffe/io.py b/python/caffe/io.py index 40b7ac1ed..11c84260f 100644 --- a/python/caffe/io.py +++ b/python/caffe/io.py @@ -20,11 +20,18 @@ def blobproto_to_array(blob, return_diff=False): Convert a blob proto to an array. In default, we will just return the data, unless return_diff is True, in which case we will return the diff. """ + # Read the data into an array if return_diff: - return np.array(blob.diff).reshape(*blob.shape.dim) + data = np.array(blob.diff) else: - return np.array(blob.data).reshape(*blob.shape.dim) + data = np.array(blob.data) + # Reshape the array + if blob.HasField('num') or blob.HasField('channels') or blob.HasField('height') or blob.HasField('width'): + # Use legacy 4D shape + return data.reshape(blob.num, blob.channels, blob.height, blob.width) + else: + return data.reshape(blob.shape.dim) def array_to_blobproto(arr, diff=None): """Converts a N-dimensional array to blob proto. If diff is given, also diff --git a/python/caffe/test/test_io.py b/python/caffe/test/test_io.py new file mode 100644 index 000000000..8c86ef75f --- /dev/null +++ b/python/caffe/test/test_io.py @@ -0,0 +1,41 @@ +import numpy as np +import unittest + +import caffe + +class TestBlobProtoToArray(unittest.TestCase): + + def test_old_format(self): + data = np.zeros((10,10)) + blob = caffe.proto.caffe_pb2.BlobProto() + blob.data.extend(list(data.flatten())) + shape = (1,1,10,10) + blob.num, blob.channels, blob.height, blob.width = shape + + arr = caffe.io.blobproto_to_array(blob) + self.assertEqual(arr.shape, shape) + + def test_new_format(self): + data = np.zeros((10,10)) + blob = caffe.proto.caffe_pb2.BlobProto() + blob.data.extend(list(data.flatten())) + blob.shape.dim.extend(list(data.shape)) + + arr = caffe.io.blobproto_to_array(blob) + self.assertEqual(arr.shape, data.shape) + + def test_no_shape(self): + data = np.zeros((10,10)) + blob = caffe.proto.caffe_pb2.BlobProto() + blob.data.extend(list(data.flatten())) + + with self.assertRaises(ValueError): + caffe.io.blobproto_to_array(blob) + + def test_scalar(self): + data = np.ones((1)) * 123 + blob = caffe.proto.caffe_pb2.BlobProto() + blob.data.extend(list(data.flatten())) + + arr = caffe.io.blobproto_to_array(blob) + self.assertEqual(arr, 123)