diff --git a/include/caffe/layers/swish_layer.hpp b/include/caffe/layers/swish_layer.hpp
new file mode 100644
index 00000000000..d538ff6de82
--- /dev/null
+++ b/include/caffe/layers/swish_layer.hpp
@@ -0,0 +1,96 @@
+#ifndef CAFFE_SWISH_LAYER_HPP_
+#define CAFFE_SWISH_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/sigmoid_layer.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Swish non-linearity @f$ y = x \sigma (\beta x) @f$.
+ *        A novel activation function that tends to work better than ReLU [1].
+ *
+ * [1] Prajit Ramachandran, Barret Zoph, Quoc V. Le. "Searching for
+ *     Activation Functions". arXiv preprint arXiv:1710.05941v2 (2017).
+ */
+template <typename Dtype>
+class SwishLayer : public NeuronLayer<Dtype> {
+ public:
+  /**
+   * @param param provides SwishParameter swish_param,
+   *     with SwishLayer options:
+   *   - beta (\b optional, default 1).
+   *     the value @f$ \beta @f$ in the @f$ y = x \sigma (\beta x) @f$.
+   */
+  explicit SwishLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param),
+        sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+        sigmoid_input_(new Blob<Dtype>()),
+        sigmoid_output_(new Blob<Dtype>()) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Swish"; }
+
+ protected:
+  /**
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$
+   * @param top output Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the computed outputs @f$
+   *        y = x \sigma (\beta x)
+   *      @f$.
+   */
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  /**
+   * @brief Computes the error gradient w.r.t. the sigmoid inputs.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+   *      with respect to computed outputs @f$ y @f$
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$; Backward fills their diff with
+   *      gradients @f$
+   *        \frac{\partial E}{\partial x}
+   *            = \frac{\partial E}{\partial y}(\beta y +
+   *              \sigma (\beta x)(1 - \beta y))
+   *      @f$ if propagate_down[0]
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  /// The internal SigmoidLayer
+  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+  /// sigmoid_input_ stores the input of the SigmoidLayer.
+  shared_ptr<Blob<Dtype> > sigmoid_input_;
+  /// sigmoid_output_ stores the output of the SigmoidLayer.
+  shared_ptr<Blob<Dtype> > sigmoid_output_;
+  /// bottom vector holder to call the underlying SigmoidLayer::Forward
+  vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+  /// top vector holder to call the underlying SigmoidLayer::Forward
+  vector<Blob<Dtype>*> sigmoid_top_vec_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_SWISH_LAYER_HPP_
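For reference, the gradient documented in the header, dE/dx = dE/dy * (beta*y + sigma(beta*x)*(1 - beta*y)), can be sanity-checked outside Caffe. The following is a minimal standalone C++ sketch, not part of the patch; the sample point x, the beta value and the step size h are arbitrary choices for illustration only:

    #include <cmath>
    #include <cstdio>

    // Swish forward: y = x * sigmoid(beta * x)
    static double swish(double x, double beta) {
      return x / (1. + std::exp(-beta * x));
    }

    // Analytic derivative from the header docs: dy/dx = beta*y + sigmoid(beta*x) * (1 - beta*y)
    static double swish_grad(double x, double beta) {
      const double y = swish(x, beta);
      const double sig = 1. / (1. + std::exp(-beta * x));
      return beta * y + sig * (1. - beta * y);
    }

    int main() {
      const double beta = 1.5, x = 0.7, h = 1e-6;
      // Central-difference approximation of dy/dx for comparison.
      const double numeric = (swish(x + h, beta) - swish(x - h, beta)) / (2. * h);
      std::printf("analytic %.8f vs numeric %.8f\n", swish_grad(x, beta), numeric);
      return 0;
    }

The analytic and numeric values should agree to several decimal places; the GradientChecker tests at the end of this patch verify the same identity inside Caffe.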
diff --git a/src/caffe/layers/swish_layer.cpp b/src/caffe/layers/swish_layer.cpp
new file mode 100644
index 00000000000..28935679d00
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cpp
@@ -0,0 +1,68 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+  sigmoid_bottom_vec_.clear();
+  sigmoid_bottom_vec_.push_back(sigmoid_input_.get());
+  sigmoid_top_vec_.clear();
+  sigmoid_top_vec_.push_back(sigmoid_output_.get());
+  sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  NeuronLayer<Dtype>::Reshape(bottom, top);
+  sigmoid_input_->ReshapeLike(*bottom[0]);
+  sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* sigmoid_input_data = sigmoid_input_->mutable_cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+  Dtype beta = this->layer_param_.swish_param().beta();
+  caffe_copy(count, bottom_data, sigmoid_input_data);
+  caffe_scal(count, beta, sigmoid_input_data);
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  caffe_mul(count, bottom_data, sigmoid_output_->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* top_data = top[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    const int count = bottom[0]->count();
+    Dtype beta = this->layer_param_.swish_param().beta();
+    for (int i = 0; i < count; ++i) {
+      const Dtype swish_x = top_data[i];
+      bottom_diff[i] = top_diff[i] * (beta * swish_x + sigmoid_output_data[i]
+          * (1. - beta * swish_x));
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SwishLayer);
+#endif
+
+INSTANTIATE_CLASS(SwishLayer);
+REGISTER_LAYER_CLASS(Swish);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/swish_layer.cu b/src/caffe/layers/swish_layer.cu
new file mode 100644
index 00000000000..c4fef53bf3a
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cu
@@ -0,0 +1,54 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* sigmoid_input_data = sigmoid_input_->mutable_gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  Dtype beta = this->layer_param_.swish_param().beta();
+  caffe_copy(count, bottom_data, sigmoid_input_data);
+  caffe_gpu_scal(count, beta, sigmoid_input_data);
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  caffe_gpu_mul(count, bottom_data, sigmoid_output_->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+__global__ void SwishBackward(const int n, const Dtype* in_diff,
+    const Dtype* out_data, const Dtype* sigmoid_output_data, Dtype* out_diff,
+    const Dtype beta) {
+  CUDA_KERNEL_LOOP(index, n) {
+    const Dtype swish_x = out_data[index];
+    out_diff[index] = in_diff[index] * (beta * swish_x
+        + sigmoid_output_data[index] * (1 - beta * swish_x));
+  }
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* top_data = top[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    const int count = bottom[0]->count();
+    Dtype beta = this->layer_param_.swish_param().beta();
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    SwishBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, top_data, sigmoid_output_data, bottom_diff, beta);
+    CUDA_POST_KERNEL_CHECK;
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SwishLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 22764abc33f..b9bb3f4dffe 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -322,7 +322,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
+// LayerParameter next available layer-specific ID: 148 (last added: swish_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -415,6 +415,7 @@ message LayerParameter {
   optional SoftmaxParameter softmax_param = 125;
   optional SPPParameter spp_param = 132;
   optional SliceParameter slice_param = 126;
+  optional SwishParameter swish_param = 147;
   optional TanHParameter tanh_param = 127;
   optional ThresholdParameter threshold_param = 128;
   optional TileParameter tile_param = 138;
@@ -1156,6 +1157,15 @@ message SoftmaxParameter {
   optional int32 axis = 2 [default = 1];
 }
 
+// Message that stores parameters used by SwishLayer
+message SwishParameter {
+  // Beta parameter for the Swish activation function
+  // Described in:
+  // Prajit Ramachandran, Barret Zoph, Quoc V. Le. (2017). Searching for
+  //   Activation Functions. https://arxiv.org/abs/1710.05941v2
+  optional float beta = 1 [default = 1];
+}
+
 message TanHParameter {
   enum Engine {
     DEFAULT = 0;
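With swish_param added to caffe.proto and the layer registered as "Swish", the activation can be selected from a net definition using the usual protobuf text format, the same syntax the tests below parse with TextFormat::ParseFromString. A hypothetical prototxt snippet, with placeholder layer and blob names ("swish1", "conv1") that are not part of this patch:

    layer {
      name: "swish1"
      type: "Swish"
      bottom: "conv1"
      top: "swish1"
      swish_param {
        beta: 1.0
      }
    }

Omitting swish_param falls back to the default beta = 1 defined above, i.e. y = x * sigma(x).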
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 180871a29ee..83d80fcd895 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -19,6 +19,7 @@
 #include "caffe/layers/prelu_layer.hpp"
 #include "caffe/layers/relu_layer.hpp"
 #include "caffe/layers/sigmoid_layer.hpp"
+#include "caffe/layers/swish_layer.hpp"
 #include "caffe/layers/tanh_layer.hpp"
 #include "caffe/layers/threshold_layer.hpp"
 
@@ -344,6 +345,84 @@ TYPED_TEST(NeuronLayerTest, TestSigmoidGradient) {
       this->blob_top_vec_);
 }
 
+TYPED_TEST(NeuronLayerTest, TestSwish) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  SwishLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-bottom_data[i])));
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBeta) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 1.5 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-1.5 *
+        bottom_data[i])));
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinear) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 0.0 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / 2.0);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  SwishLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBetaGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 1.5 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinearGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "swish_param { beta: 0.0 }", &layer_param));
+  SwishLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+  checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
 TYPED_TEST(NeuronLayerTest, TestTanH) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;