Skip to content

Commit

Permalink
TestConcatLayer: add forward/gradient tests for concatenation along num
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdonahue committed Mar 3, 2015
1 parent 704e524 commit d52e9a8
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions src/caffe/test/test_concat_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,36 @@ TYPED_TEST(ConcatLayerTest, TestSetupChannels) {
EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_0_->width());
}

TYPED_TEST(ConcatLayerTest, TestForwardNum) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_concat_param()->set_concat_dim(0);
ConcatLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_1_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_1_, this->blob_top_vec_);
for (int n = 0; n < this->blob_bottom_vec_1_[0]->num(); ++n) {
for (int c = 0; c < this->blob_top_->channels(); ++c) {
for (int h = 0; h < this->blob_top_->height(); ++h) {
for (int w = 0; w < this->blob_top_->width(); ++w) {
EXPECT_EQ(this->blob_top_->data_at(n, c, h, w),
this->blob_bottom_vec_1_[0]->data_at(n, c, h, w));
}
}
}
}
for (int n = 0; n < this->blob_bottom_vec_1_[1]->num(); ++n) {
for (int c = 0; c < this->blob_top_->channels(); ++c) {
for (int h = 0; h < this->blob_top_->height(); ++h) {
for (int w = 0; w < this->blob_top_->width(); ++w) {
EXPECT_EQ(this->blob_top_->data_at(n + 2, c, h, w),
this->blob_bottom_vec_1_[1]->data_at(n, c, h, w));
}
}
}
}
}

TYPED_TEST(ConcatLayerTest, TestNum) {
TYPED_TEST(ConcatLayerTest, TestForwardChannels) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
ConcatLayer<Dtype> layer(layer_param);
Expand All @@ -110,7 +138,17 @@ TYPED_TEST(ConcatLayerTest, TestNum) {
}
}

TYPED_TEST(ConcatLayerTest, TestGradient) {
TYPED_TEST(ConcatLayerTest, TestGradientNum) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_concat_param()->set_concat_dim(0);
ConcatLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-2);
checker.CheckGradient(&layer, this->blob_bottom_vec_1_,
this->blob_top_vec_);
}

TYPED_TEST(ConcatLayerTest, TestGradientChannels) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
ConcatLayer<Dtype> layer(layer_param);
Expand Down

0 comments on commit d52e9a8

Please sign in to comment.