Skip to content

Commit

Permalink
Merge pull request #2082 from jeffdonahue/flatten-layer-axis
Browse files Browse the repository at this point in the history
FlattenLayer gets a FlattenParameter with an axis, end_axis
  • Loading branch information
jeffdonahue committed Jun 3, 2015
2 parents 2d137e1 + eb442b9 commit f3eabad
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 10 deletions.
16 changes: 13 additions & 3 deletions src/caffe/layers/flatten_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@ namespace caffe {
template <typename Dtype>
void FlattenLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
vector<int> top_shape(2);
top_shape[0] = bottom[0]->num();
top_shape[1] = bottom[0]->count() / bottom[0]->num();
const int start_axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.flatten_param().axis());
const int end_axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.flatten_param().end_axis());
vector<int> top_shape;
for (int i = 0; i < start_axis; ++i) {
top_shape.push_back(bottom[0]->shape(i));
}
const int flattened_dim = bottom[0]->count(start_axis, end_axis + 1);
top_shape.push_back(flattened_dim);
for (int i = end_axis + 1; i < bottom[0]->num_axes(); ++i) {
top_shape.push_back(bottom[0]->shape(i));
}
top[0]->Reshape(top_shape);
CHECK_EQ(top[0]->count(), bottom[0]->count());
}
Expand Down
16 changes: 15 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 135 (last added: log_param)
// LayerParameter next available layer-specific ID: 136 (last added: flatten_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
Expand Down Expand Up @@ -326,6 +326,7 @@ message LayerParameter {
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
optional ExpParameter exp_param = 111;
optional FlattenParameter flatten_param = 135;
optional HDF5DataParameter hdf5_data_param = 112;
optional HDF5OutputParameter hdf5_output_param = 113;
optional HingeLossParameter hinge_loss_param = 114;
Expand Down Expand Up @@ -533,6 +534,19 @@ message ExpParameter {
optional float shift = 3 [default = 0.0];
}

/// Message that stores parameters used by FlattenLayer
message FlattenParameter {
// The first axis to flatten: all preceding axes are retained in the output.
// May be negative to index from the end (e.g., -1 for the last axis).
optional int32 axis = 1 [default = 1];

// The last axis to flatten: all following axes are retained in the output.
// May be negative to index from the end (e.g., the default -1 for the last
// axis).
optional int32 end_axis = 2 [default = -1];
}

// Message that stores parameters used by HDF5DataLayer
message HDF5DataParameter {
// Specify the data source.
optional string source = 1;
Expand Down
46 changes: 40 additions & 6 deletions src/caffe/test/test_flatten_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,48 @@ TYPED_TEST(FlattenLayerTest, TestSetup) {
LayerParameter layer_param;
FlattenLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
EXPECT_EQ(this->blob_top_->num(), 2);
EXPECT_EQ(this->blob_top_->channels(), 3 * 6 * 5);
EXPECT_EQ(this->blob_top_->height(), 1);
EXPECT_EQ(this->blob_top_->width(), 1);
ASSERT_EQ(this->blob_top_->num_axes(), 2);
EXPECT_EQ(this->blob_top_->shape(0), 2);
EXPECT_EQ(this->blob_top_->shape(1), 3 * 6 * 5);
}

TYPED_TEST(FlattenLayerTest, Test) {
TYPED_TEST(FlattenLayerTest, TestSetupWithAxis) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_flatten_param()->set_axis(2);
FlattenLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
ASSERT_EQ(this->blob_top_->num_axes(), 3);
EXPECT_EQ(this->blob_top_->shape(0), 2);
EXPECT_EQ(this->blob_top_->shape(1), 3);
EXPECT_EQ(this->blob_top_->shape(2), 6 * 5);
}

TYPED_TEST(FlattenLayerTest, TestSetupWithEndAxis) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_flatten_param()->set_end_axis(-2);
FlattenLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
ASSERT_EQ(this->blob_top_->num_axes(), 3);
EXPECT_EQ(this->blob_top_->shape(0), 2);
EXPECT_EQ(this->blob_top_->shape(1), 3 * 6);
EXPECT_EQ(this->blob_top_->shape(2), 5);
}

TYPED_TEST(FlattenLayerTest, TestSetupWithStartAndEndAxis) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_flatten_param()->set_axis(0);
layer_param.mutable_flatten_param()->set_end_axis(-2);
FlattenLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
ASSERT_EQ(this->blob_top_->num_axes(), 2);
EXPECT_EQ(this->blob_top_->shape(0), 2 * 3 * 6);
EXPECT_EQ(this->blob_top_->shape(1), 5);
}

TYPED_TEST(FlattenLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
FlattenLayer<Dtype> layer(layer_param);
Expand All @@ -71,5 +106,4 @@ TYPED_TEST(FlattenLayerTest, TestGradient) {
this->blob_top_vec_);
}


} // namespace caffe

0 comments on commit f3eabad

Please sign in to comment.