Skip to content

Commit

Permalink
Merge pull request BVLC#4343 from nitnelave/python/top_names
Browse files Browse the repository at this point in the history
improve top_names and bottom_names in pycaffe
  • Loading branch information
longjon authored Jul 13, 2016
2 parents 2f49fd2 + 7c50a2c commit c2a1ecd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
40 changes: 25 additions & 15 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,21 +292,31 @@ def _Net_batch(self, blobs):
padding])
yield padded_batch


class _Net_IdNameWrapper:
"""
A simple wrapper that allows the ids propery to be accessed as a dict
indexed by names. Used for top and bottom names
def _Net_get_id_name(func, field):
"""
def __init__(self, net, func):
self.net, self.func = net, func
Generic property that maps func to the layer names into an OrderedDict.
Used for top_names and bottom_names.
def __getitem__(self, name):
# Map the layer name to id
ids = self.func(self.net, list(self.net._layer_names).index(name))
# Map the blob id to name
id_to_name = list(self.net.blobs)
return [id_to_name[i] for i in ids]
Parameters
----------
func: function id -> [id]
field: implementation field name (cache)
Returns
------
A one-parameter function that can be set as a property.
"""
@property
def get_id_name(self):
if not hasattr(self, field):
id_to_name = list(self.blobs)
res = OrderedDict([(self._layer_names[i],
[id_to_name[j] for j in func(self, i)])
for i in range(len(self.layers))])
setattr(self, field, res)
return getattr(self, field)
return get_id_name

# Attach methods to Net.
Net.blobs = _Net_blobs
Expand All @@ -320,5 +330,5 @@ def __getitem__(self, name):
Net._batch = _Net_batch
Net.inputs = _Net_inputs
Net.outputs = _Net_outputs
Net.top_names = property(lambda n: _Net_IdNameWrapper(n, Net._top_ids))
Net.bottom_names = property(lambda n: _Net_IdNameWrapper(n, Net._bottom_ids))
Net.top_names = _Net_get_id_name(Net._top_ids, "_top_names")
Net.bottom_names = _Net_get_id_name(Net._bottom_ids, "_bottom_names")
13 changes: 13 additions & 0 deletions python/caffe/test/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import numpy as np
import six
from collections import OrderedDict

import caffe

Expand Down Expand Up @@ -78,6 +79,18 @@ def test_inputs_outputs(self):
self.assertEqual(self.net.inputs, [])
self.assertEqual(self.net.outputs, ['loss'])

def test_top_bottom_names(self):
self.assertEqual(self.net.top_names,
OrderedDict([('data', ['data', 'label']),
('conv', ['conv']),
('ip', ['ip']),
('loss', ['loss'])]))
self.assertEqual(self.net.bottom_names,
OrderedDict([('data', []),
('conv', ['data']),
('ip', ['conv']),
('loss', ['ip', 'label'])]))

def test_save_and_read(self):
f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
f.close()
Expand Down

0 comments on commit c2a1ecd

Please sign in to comment.