From eb442b9bc9ca206bd0606f259115f01b53144e7a Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 31 Dec 2014 18:02:12 -0800 Subject: [PATCH] FlattenLayer gets a FlattenParameter with an axis, end_axis --- src/caffe/layers/flatten_layer.cpp | 16 ++++++++-- src/caffe/proto/caffe.proto | 16 +++++++++- src/caffe/test/test_flatten_layer.cpp | 46 +++++++++++++++++++++++---- 3 files changed, 68 insertions(+), 10 deletions(-) diff --git a/src/caffe/layers/flatten_layer.cpp b/src/caffe/layers/flatten_layer.cpp index 745f271ea..f7e5c9c21 100644 --- a/src/caffe/layers/flatten_layer.cpp +++ b/src/caffe/layers/flatten_layer.cpp @@ -9,9 +9,19 @@ namespace caffe { template void FlattenLayer::Reshape(const vector*>& bottom, const vector*>& top) { - vector top_shape(2); - top_shape[0] = bottom[0]->num(); - top_shape[1] = bottom[0]->count() / bottom[0]->num(); + const int start_axis = bottom[0]->CanonicalAxisIndex( + this->layer_param_.flatten_param().axis()); + const int end_axis = bottom[0]->CanonicalAxisIndex( + this->layer_param_.flatten_param().end_axis()); + vector top_shape; + for (int i = 0; i < start_axis; ++i) { + top_shape.push_back(bottom[0]->shape(i)); + } + const int flattened_dim = bottom[0]->count(start_axis, end_axis + 1); + top_shape.push_back(flattened_dim); + for (int i = end_axis + 1; i < bottom[0]->num_axes(); ++i) { + top_shape.push_back(bottom[0]->shape(i)); + } top[0]->Reshape(top_shape); CHECK_EQ(top[0]->count(), bottom[0]->count()); } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 619642f2d..f79cf80cc 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -269,7 +269,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 135 (last added: log_param) +// LayerParameter next available layer-specific ID: 136 (last added: flatten_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -326,6 +326,7 @@ message LayerParameter { optional DummyDataParameter dummy_data_param = 109; optional EltwiseParameter eltwise_param = 110; optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; optional HDF5DataParameter hdf5_data_param = 112; optional HDF5OutputParameter hdf5_output_param = 113; optional HingeLossParameter hinge_loss_param = 114; @@ -533,6 +534,19 @@ message ExpParameter { optional float shift = 3 [default = 0.0]; } +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer message HDF5DataParameter { // Specify the data source. optional string source = 1; diff --git a/src/caffe/test/test_flatten_layer.cpp b/src/caffe/test/test_flatten_layer.cpp index 3042d293c..7b6757cba 100644 --- a/src/caffe/test/test_flatten_layer.cpp +++ b/src/caffe/test/test_flatten_layer.cpp @@ -42,13 +42,48 @@ TYPED_TEST(FlattenLayerTest, TestSetup) { LayerParameter layer_param; FlattenLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); - EXPECT_EQ(this->blob_top_->num(), 2); - EXPECT_EQ(this->blob_top_->channels(), 3 * 6 * 5); - EXPECT_EQ(this->blob_top_->height(), 1); - EXPECT_EQ(this->blob_top_->width(), 1); + ASSERT_EQ(this->blob_top_->num_axes(), 2); + EXPECT_EQ(this->blob_top_->shape(0), 2); + EXPECT_EQ(this->blob_top_->shape(1), 3 * 6 * 5); } -TYPED_TEST(FlattenLayerTest, Test) { +TYPED_TEST(FlattenLayerTest, TestSetupWithAxis) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_flatten_param()->set_axis(2); + FlattenLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_top_->num_axes(), 3); + EXPECT_EQ(this->blob_top_->shape(0), 2); + EXPECT_EQ(this->blob_top_->shape(1), 3); + EXPECT_EQ(this->blob_top_->shape(2), 6 * 5); +} + +TYPED_TEST(FlattenLayerTest, TestSetupWithEndAxis) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_flatten_param()->set_end_axis(-2); + FlattenLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_top_->num_axes(), 3); + EXPECT_EQ(this->blob_top_->shape(0), 2); + EXPECT_EQ(this->blob_top_->shape(1), 3 * 6); + EXPECT_EQ(this->blob_top_->shape(2), 5); +} + +TYPED_TEST(FlattenLayerTest, TestSetupWithStartAndEndAxis) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_flatten_param()->set_axis(0); + layer_param.mutable_flatten_param()->set_end_axis(-2); + FlattenLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_top_->num_axes(), 2); + EXPECT_EQ(this->blob_top_->shape(0), 2 * 3 * 6); + EXPECT_EQ(this->blob_top_->shape(1), 5); +} + +TYPED_TEST(FlattenLayerTest, TestForward) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; FlattenLayer layer(layer_param); @@ -71,5 +106,4 @@ TYPED_TEST(FlattenLayerTest, TestGradient) { this->blob_top_vec_); } - } // namespace caffe