Skip to content

Commit

Permalink
Merge pull request BVLC#5545 from brunobowden/shape_mismatch_checks
Browse files Browse the repository at this point in the history
Shape mismatch CHECK logging improvements
  • Loading branch information
Noiredd authored Feb 2, 2018
2 parents 88c9618 + fb0795c commit d2627e9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/caffe/layers/base_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ void BaseConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
// TODO: generalize to handle inputs of different shapes.
for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) {
CHECK(bottom[0]->shape() == bottom[bottom_id]->shape())
<< "All inputs must have the same shape.";
<< "shape mismatch - bottom[0]: " << bottom[0]->shape_string()
<< " vs. bottom[" << bottom_id << "]: "
<< bottom[bottom_id]->shape_string();
}
// Shape the tops.
bottom_shape_ = &bottom[0]->shape();
Expand Down
5 changes: 3 additions & 2 deletions src/caffe/layers/recurrent_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,9 @@ void RecurrentLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const int bottom_offset = 2 + static_input_;
for (int i = bottom_offset, j = 0; i < bottom.size(); ++i, ++j) {
CHECK(recur_input_blobs_[j]->shape() == bottom[i]->shape())
<< "bottom[" << i << "] shape must match hidden state input shape: "
<< recur_input_blobs_[j]->shape_string();
<< "shape mismatch - recur_input_blobs_[" << j << "]: "
<< recur_input_blobs_[j]->shape_string()
<< " vs. bottom[" << i << "]: " << bottom[i]->shape_string();
recur_input_blobs_[j]->ShareData(*bottom[i]);
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/caffe/layers/slice_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ void SliceLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
int count = 0;
if (slice_point_.size() != 0) {
CHECK_EQ(slice_point_.size(), top.size() - 1);
CHECK_LE(top.size(), bottom_slice_axis);
CHECK_LE(top.size(), bottom_slice_axis)
<< "slice axis: " << slice_axis_
<< ", bottom[0] shape: " << bottom[0]->shape_string();
int prev = 0;
vector<int> slices;
for (int i = 0; i < slice_point_.size(); ++i) {
Expand Down

0 comments on commit d2627e9

Please sign in to comment.