Skip to content

Commit

Permalink
Add fused convolution and mish layer support. (Tencent#1761)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiliu6 authored May 17, 2020
1 parent c35bf2f commit 3bfabf1
Show file tree
Hide file tree
Showing 38 changed files with 424 additions and 4 deletions.
15 changes: 15 additions & 0 deletions src/layer/arm/convolution_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ int Convolution_arm::create_pipeline(const Option& opt)
ncnn::ParamDict pd;
activation->load_param(pd);
}
else if (activation_type == 5)
{
activation = ncnn::create_layer(ncnn::LayerType::Mish);

ncnn::ParamDict pd;
activation->load_param(pd);
}

if (activation)
{
Expand Down Expand Up @@ -993,6 +1000,10 @@ int Convolution_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option
{
sum = static_cast<float>(1.f / (1.f + exp(-sum)));
}
else if (activation_type == 5)
{
sum = static_cast<float>(sum * tanh(log(exp(sum) + 1.f)));
}

outptr[j] = sum;
}
Expand Down Expand Up @@ -1622,6 +1633,10 @@ int Convolution_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const
{
sum = static_cast<float>(1.f / (1.f + exp(-sum)));
}
else if (activation_type == 5)
{
sum = static_cast<float>(sum * tanh(log(exp(sum) + 1.f)));
}

outptr[j] = float32_to_bfloat16(sum);
}
Expand Down
11 changes: 11 additions & 0 deletions src/layer/arm/convolutiondepthwise_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ int ConvolutionDepthWise_arm::create_pipeline(const Option& opt)
ncnn::ParamDict pd;
activation->load_param(pd);
}
else if (activation_type == 5)
{
activation = ncnn::create_layer(ncnn::LayerType::Mish);

ncnn::ParamDict pd;
activation->load_param(pd);
}

if (activation)
{
Expand Down Expand Up @@ -753,6 +760,10 @@ int ConvolutionDepthWise_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blo
{
sum = static_cast<float>(1.f / (1.f + exp(-sum)));
}
else if (activation_type == 5)
{
sum = static_cast<float>(sum * tanh(log(exp(sum) + 1.f)));
}

outptr[j] = float32_to_bfloat16(sum);
}
Expand Down
162 changes: 162 additions & 0 deletions src/layer/arm/mish_arm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "mish_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#include "neon_mathfun.h"
#endif // __ARM_NEON

#include <math.h>

namespace ncnn {

DEFINE_LAYER_CREATOR(Mish_arm)

// Advertise which optional storage layouts this implementation accepts.
Mish_arm::Mish_arm()
{
    // forward_inplace() hands off to forward_inplace_bf16s() when bf16 storage is in use.
    support_bf16_storage = true;

#if __ARM_NEON
    // The elempack == 4 branches are compiled only when NEON is available.
    support_packing = true;
#endif // __ARM_NEON
}

// Apply Mish, x * tanh(log(exp(x) + 1)), to every float of the blob in place.
//
// bottom_top_blob  blob that is read and overwritten
// opt              use_bf16_storage redirects to forward_inplace_bf16s();
//                  num_threads is passed to the OpenMP parallel-for over channels
//
// Returns 0, or the result of forward_inplace_bf16s() when that path is taken.
int Mish_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    if (opt.use_bf16_storage)
        return forward_inplace_bf16s(bottom_top_blob, opt);

    const int channels = bottom_top_blob.c;
    const int size = bottom_top_blob.w * bottom_top_blob.h;

#if __ARM_NEON
    if (bottom_top_blob.elempack == 4)
    {
        // Packed layout: every one of the `size` steps handles four consecutive floats.
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);
            const float32x4_t _one = vdupq_n_f32(1.f);

            for (int left = size; left > 0; left--)
            {
                const float32x4_t _x = vld1q_f32(ptr);
                const float32x4_t _softplus = log_ps(vaddq_f32(exp_ps(_x), _one));
                vst1q_f32(ptr, vmulq_f32(_x, tanh_ps(_softplus)));
                ptr += 4;
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);

        // Number of elements left for the scalar loop below.
        int tail = size;

#if __ARM_NEON
        // Vectorised body: groups of four floats; the leftover (< 4) goes to the scalar loop.
        const int quads = size >> 2;
        tail = size - (quads << 2);

        const float32x4_t _one = vdupq_n_f32(1.f);
        for (int k = 0; k < quads; k++)
        {
            const float32x4_t _x = vld1q_f32(ptr);
            const float32x4_t _softplus = log_ps(vaddq_f32(exp_ps(_x), _one));
            vst1q_f32(ptr, vmulq_f32(_x, tanh_ps(_softplus)));
            ptr += 4;
        }
#endif // __ARM_NEON

        while (tail > 0)
        {
            const float x = *ptr;
            *ptr = x * tanh(log(exp(x) + 1.f));
            ptr++;
            tail--;
        }
    }

    return 0;
}

// Apply Mish, x * tanh(log(exp(x) + 1)), in place to a blob stored as 16-bit values.
//
// Each stored value is expanded to a float, transformed, and written back as 16 bits.
// In the NEON paths the expansion is a left shift by 16 into a 32-bit lane
// (vshll_n_u16) and the write-back keeps only the upper 16 bits (vshrn_n_u32);
// the scalar path uses bfloat16_to_float32() / float32_to_bfloat16().
//
// bottom_top_blob  blob that is read and overwritten
// opt              num_threads is passed to the OpenMP parallel-for over channels
//
// Always returns 0.
int Mish_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
{
    const int channels = bottom_top_blob.c;
    const int size = bottom_top_blob.w * bottom_top_blob.h;

#if __ARM_NEON
    if (bottom_top_blob.elempack == 4)
    {
        // Packed layout: every one of the `size` steps handles four consecutive values.
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            unsigned short* ptr = bottom_top_blob.channel(q);
            const float32x4_t _one = vdupq_n_f32(1.f);

            for (int left = size; left > 0; left--)
            {
                const float32x4_t _x = vreinterpretq_f32_u32(vshll_n_u16(vld1_u16(ptr), 16));
                const float32x4_t _softplus = log_ps(vaddq_f32(exp_ps(_x), _one));
                const float32x4_t _y = vmulq_f32(_x, tanh_ps(_softplus));
                vst1_u16(ptr, vshrn_n_u32(vreinterpretq_u32_f32(_y), 16));
                ptr += 4;
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        unsigned short* ptr = bottom_top_blob.channel(q);

        // Number of elements left for the scalar loop below.
        int tail = size;

#if __ARM_NEON
        // Vectorised body: groups of four values; the leftover (< 4) goes to the scalar loop.
        const int quads = size >> 2;
        tail = size - (quads << 2);

        const float32x4_t _one = vdupq_n_f32(1.f);
        for (int k = 0; k < quads; k++)
        {
            const float32x4_t _x = vreinterpretq_f32_u32(vshll_n_u16(vld1_u16(ptr), 16));
            const float32x4_t _softplus = log_ps(vaddq_f32(exp_ps(_x), _one));
            const float32x4_t _y = vmulq_f32(_x, tanh_ps(_softplus));
            vst1_u16(ptr, vshrn_n_u32(vreinterpretq_u32_f32(_y), 16));
            ptr += 4;
        }
#endif // __ARM_NEON

        while (tail > 0)
        {
            const float x = bfloat16_to_float32(*ptr);
            const float y = x * tanh(log(exp(x) + 1.f));
            *ptr = float32_to_bfloat16(y);
            ptr++;
            tail--;
        }
    }

    return 0;
}

} // namespace ncnn
35 changes: 35 additions & 0 deletions src/layer/arm/mish_arm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_MISH_ARM_H
#define LAYER_MISH_ARM_H

#include "mish.h"

namespace ncnn {

// Mish layer variant whose implementation (mish_arm.cpp) uses NEON intrinsics
// when __ARM_NEON is defined and plain scalar math otherwise.
class Mish_arm : virtual public Mish
{
public:
// Sets the layer's support_bf16_storage flag, and support_packing in NEON builds.
Mish_arm();

// Applies x * tanh(log(exp(x) + 1)) to bottom_top_blob in place.
// Delegates to forward_inplace_bf16s() when opt.use_bf16_storage is set.
// Returns 0 on the float path.
virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;

protected:
// Same transform for blobs stored as 16-bit values: each element is expanded
// to float, transformed, and written back as 16 bits. Returns 0.
int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const;
};

} // namespace ncnn

#endif // LAYER_MISH_ARM_H
8 changes: 8 additions & 0 deletions src/layer/arm/neon_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ static inline float activation_ss(float v, int activation_type, const ncnn::Mat&
{
v = 1.f / (1.f + exp(-v));
}
else if (activation_type == 5)
{
v = v * tanh(log(exp(v) + 1.f));
}

return v;
}
Expand Down Expand Up @@ -78,6 +82,10 @@ static inline float32x4_t activation_ps(float32x4_t _v, int activation_type, con
// _outp = vmulq_f32(vrecpsq_f32(_v, _outp), _outp);
_v = _outp;
}
else if (activation_type == 5)
{
_v = vmulq_f32(_v, tanh_ps(log_ps(vaddq_f32(exp_ps(_v), vdupq_n_f32(1.f)))));
}

return _v;
}
Expand Down
12 changes: 12 additions & 0 deletions src/layer/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,18 @@ int Convolution::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op
{
sum = static_cast<float>(1.f / (1.f + exp(-sum)));
}
else if (activation_type == 5)
{
const float MISH_THRESHOLD = 20;
float x = sum, y;
if (x > MISH_THRESHOLD)
y = x;
else if (x < -MISH_THRESHOLD)
y = expf(x);
else
y = logf(expf(x) + 1);
sum = static_cast<float>(x * tanh(y));
}

outptr[j] = sum;
}
Expand Down
24 changes: 24 additions & 0 deletions src/layer/convolutiondepthwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,18 @@ int ConvolutionDepthWise::forward(const Mat& bottom_blob, Mat& top_blob, const O
{
sum = static_cast<float>(1.f / (1.f + exp(-sum)));
}
else if (activation_type == 5)
{
const float MISH_THRESHOLD = 20;
float x = sum, y;
if (x > MISH_THRESHOLD)
y = x;
else if (x < -MISH_THRESHOLD)
y = expf(x);
else
y = logf(expf(x) + 1);
sum = static_cast<float>(x * tanh(y));
}

outptr[j] = sum;
}
Expand Down Expand Up @@ -313,6 +325,18 @@ int ConvolutionDepthWise::forward(const Mat& bottom_blob, Mat& top_blob, const O
{
sum = static_cast<float>(1.f / (1.f + exp(-sum)));
}
else if (activation_type == 5)
{
const float MISH_THRESHOLD = 20;
float x = sum, y;
if (x > MISH_THRESHOLD)
y = x;
else if (x < -MISH_THRESHOLD)
y = expf(x);
else
y = logf(expf(x) + 1);
sum = static_cast<float>(x * tanh(y));
}

outptr[j] = sum;
}
Expand Down
4 changes: 4 additions & 0 deletions src/layer/vulkan/shader/convolution.comp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ void main()
{
sum = afp(1.f) / (afp(1.f) + exp(-sum));
}
if (activation_type == 5)
{
sum = sum * tanh(log(exp(sum) + afp(1.f)));
}

#if NCNN_image_shader
image3d_st1(top_blob, ivec3(gx, gy, gz), sum);
Expand Down
4 changes: 4 additions & 0 deletions src/layer/vulkan/shader/convolution_1x1s1d1.comp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ void main()
{
sum = afp(1.f) / (afp(1.f) + exp(-sum));
}
if (activation_type == 5)
{
sum = sum * tanh(log(exp(sum) + afp(1.f)));
}

#if NCNN_image_shader
image3d_st1(top_blob, ivec3(gx, gy, gz), sum.r);
Expand Down
4 changes: 4 additions & 0 deletions src/layer/vulkan/shader/convolution_pack1to4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ void main()
{
sum = afp(1.f) / (afp(1.f) + exp(-sum));
}
if (activation_type == 5)
{
sum = sum * tanh(log(exp(sum) + afp(1.f)));
}

#if NCNN_image_shader
image3d_st4(top_blob, ivec3(gx, gy, gz), sum);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/convolution_pack1to8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ void main()
sum[0] = afp(1.f) / (afp(1.f) + exp(-sum[0]));
sum[1] = afp(1.f) / (afp(1.f) + exp(-sum[1]));
}
if (activation_type == 5)
{
sum[0] = sum[0] * tanh(log(exp(sum[0]) + afp(1.f)));
sum[1] = sum[1] * tanh(log(exp(sum[1]) + afp(1.f)));
}

#if NCNN_image_shader
image3d_st8(top_blob, ivec3(gx, gy, gz), sum);
Expand Down
4 changes: 4 additions & 0 deletions src/layer/vulkan/shader/convolution_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ void main()
{
sum = afp(1.f) / (afp(1.f) + exp(-sum));
}
if (activation_type == 5)
{
sum = sum * tanh(log(exp(sum) + afp(1.f)));
}

#if NCNN_image_shader
image3d_st4(top_blob, ivec3(gx, gy, gz), sum);
Expand Down
Loading

0 comments on commit 3bfabf1

Please sign in to comment.