Skip to content

Commit

Permalink
Exposing layer top and bottom names to python
Browse files Browse the repository at this point in the history
  • Loading branch information
philkr committed Jan 5, 2016
1 parent 708c1a1 commit 1137e89
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/caffe/net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ class Net {
inline const vector<vector<Blob<Dtype>*> >& top_vecs() const {
return top_vecs_;
}
/// @brief returns the ids of the top blobs of layer i
inline const vector<int> & top_ids(int i) const {
CHECK_GE(i, 0) << "Invalid layer id";
CHECK_LT(i, top_id_vecs_.size()) << "Invalid layer id";
return top_id_vecs_[i];
}
/// @brief returns the ids of the bottom blobs of layer i
inline const vector<int> & bottom_ids(int i) const {
CHECK_GE(i, 0) << "Invalid layer id";
CHECK_LT(i, bottom_id_vecs_.size()) << "Invalid layer id";
return bottom_id_vecs_[i];
}
inline const vector<vector<bool> >& bottom_need_backward() const {
return bottom_need_backward_;
}
Expand Down
4 changes: 4 additions & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("share_with", &Net<Dtype>::ShareTrainedLayersWith)
.add_property("_blob_loss_weights", bp::make_function(
&Net<Dtype>::blob_loss_weights, bp::return_internal_reference<>()))
.def("_bottom_ids", bp::make_function(&Net<Dtype>::bottom_ids,
bp::return_value_policy<bp::copy_const_reference>()))
.def("_top_ids", bp::make_function(&Net<Dtype>::top_ids,
bp::return_value_policy<bp::copy_const_reference>()))
.add_property("_blobs", bp::make_function(&Net<Dtype>::blobs,
bp::return_internal_reference<>()))
.add_property("layers", bp::make_function(&Net<Dtype>::layers,
Expand Down
18 changes: 18 additions & 0 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,22 @@ def _Net_batch(self, blobs):
padding])
yield padded_batch


class _Net_IdNameWrapper:
"""
A simple wrapper that allows the ids propery to be accessed as a dict
indexed by names. Used for top and bottom names
"""
def __init__(self, net, func):
self.net, self.func = net, func

def __getitem__(self, name):
# Map the layer name to id
ids = self.func(self.net, list(self.net._layer_names).index(name))
# Map the blob id to name
id_to_name = list(self.net.blobs)
return [id_to_name[i] for i in ids]

# Attach methods to Net.
Net.blobs = _Net_blobs
Net.blob_loss_weights = _Net_blob_loss_weights
Expand All @@ -288,3 +304,5 @@ def _Net_batch(self, blobs):
Net._batch = _Net_batch
Net.inputs = _Net_inputs
Net.outputs = _Net_outputs
Net.top_names = property(lambda n: _Net_IdNameWrapper(n, Net._top_ids))
Net.bottom_names = property(lambda n: _Net_IdNameWrapper(n, Net._bottom_ids))

0 comments on commit 1137e89

Please sign in to comment.