Skip to content

Commit

Permalink
Merge pull request BVLC#6286 from Noiredd/bilinear-filler-fix
Browse files Browse the repository at this point in the history
BilinearFiller tests refactored
  • Loading branch information
Noiredd authored Mar 12, 2018
2 parents 69da2cf + eb62919 commit f049522
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions src/caffe/test/test_filler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,21 +500,25 @@ TYPED_TEST(MSRAFillerTest, TestFill5D) {
template <typename Dtype>
class BilinearFillerTest : public ::testing::Test {
protected:
BilinearFillerTest() : filler_param_() {}
virtual void test_params(const int n) {
this->blob_ = new Blob<Dtype>(1000, 2, n, n);
this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
this->filler_->Fill(blob_);
EXPECT_TRUE(this->blob_);
const int outer_num = this->blob_->count(0, 2);
const int inner_num = this->blob_->count(2, 4);
const Dtype* data = this->blob_->cpu_data();
int f = ceil(this->blob_->width() / 2.);
Dtype c = (this->blob_->width() - 1) / (2. * f);
BilinearFillerTest()
: blob_(new Blob<Dtype>()),
filler_param_() {
}
virtual void test_params(const vector<int>& shape) {
EXPECT_TRUE(blob_);
blob_->Reshape(shape);
filler_.reset(new BilinearFiller<Dtype>(filler_param_));
filler_->Fill(blob_);
CHECK_EQ(blob_->num_axes(), 4);
const int outer_num = blob_->count(0, 2);
const int inner_num = blob_->count(2, 4);
const Dtype* data = blob_->cpu_data();
int f = ceil(blob_->shape(3) / 2.);
Dtype c = (blob_->shape(3) - 1) / (2. * f);
for (int i = 0; i < outer_num; ++i) {
for (int j = 0; j < inner_num; ++j) {
Dtype x = j % this->blob_->width();
Dtype y = (j / this->blob_->width()) % this->blob_->height();
Dtype x = j % blob_->shape(3);
Dtype y = (j / blob_->shape(3)) % blob_->shape(2);
Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
const Dtype actual_value = data[i * inner_num + j];
EXPECT_NEAR(expected_value, actual_value, 0.01);
Expand All @@ -531,11 +535,21 @@ TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);

TYPED_TEST(BilinearFillerTest, TestFillOdd) {
const int n = 7;
this->test_params(n);
vector<int> blob_shape;
blob_shape.push_back(1000);
blob_shape.push_back(2);
blob_shape.push_back(n);
blob_shape.push_back(n);
this->test_params(blob_shape);
}
TYPED_TEST(BilinearFillerTest, TestFillEven) {
const int n = 6;
this->test_params(n);
vector<int> blob_shape;
blob_shape.push_back(1000);
blob_shape.push_back(2);
blob_shape.push_back(n);
blob_shape.push_back(n);
this->test_params(blob_shape);
}

} // namespace caffe

0 comments on commit f049522

Please sign in to comment.