Skip to content

Commit

Permalink
Allow old-style shape in blobproto_to_array
Browse files Browse the repository at this point in the history
Fixes #3199
Bug introduced in #3170
  • Loading branch information
lukeyeager committed Oct 15, 2015
1 parent 8c8e832 commit 75e859a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
11 changes: 9 additions & 2 deletions python/caffe/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@ def blobproto_to_array(blob, return_diff=False):
Convert a blob proto to an array. In default, we will just return the data,
unless return_diff is True, in which case we will return the diff.
"""
# Read the data into an array
if return_diff:
return np.array(blob.diff).reshape(*blob.shape.dim)
data = np.array(blob.diff)
else:
return np.array(blob.data).reshape(*blob.shape.dim)
data = np.array(blob.data)

# Reshape the array
if blob.HasField('num') or blob.HasField('channels') or blob.HasField('height') or blob.HasField('width'):
# Use legacy 4D shape
return data.reshape(blob.num, blob.channels, blob.height, blob.width)
else:
return data.reshape(blob.shape.dim)

def array_to_blobproto(arr, diff=None):
"""Converts a N-dimensional array to blob proto. If diff is given, also
Expand Down
41 changes: 41 additions & 0 deletions python/caffe/test/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import unittest

import caffe

class TestBlobProtoToArray(unittest.TestCase):

def test_old_format(self):
data = np.zeros((10,10))
blob = caffe.proto.caffe_pb2.BlobProto()
blob.data.extend(list(data.flatten()))
shape = (1,1,10,10)
blob.num, blob.channels, blob.height, blob.width = shape

arr = caffe.io.blobproto_to_array(blob)
self.assertEqual(arr.shape, shape)

def test_new_format(self):
data = np.zeros((10,10))
blob = caffe.proto.caffe_pb2.BlobProto()
blob.data.extend(list(data.flatten()))
blob.shape.dim.extend(list(data.shape))

arr = caffe.io.blobproto_to_array(blob)
self.assertEqual(arr.shape, data.shape)

def test_no_shape(self):
data = np.zeros((10,10))
blob = caffe.proto.caffe_pb2.BlobProto()
blob.data.extend(list(data.flatten()))

with self.assertRaises(ValueError):
caffe.io.blobproto_to_array(blob)

def test_scalar(self):
data = np.ones((1)) * 123
blob = caffe.proto.caffe_pb2.BlobProto()
blob.data.extend(list(data.flatten()))

arr = caffe.io.blobproto_to_array(blob)
self.assertEqual(arr, 123)

0 comments on commit 75e859a

Please sign in to comment.