Skip to content

Commit

Permalink
add ONNX OP sign, shrink and reciprocal
Browse files Browse the repository at this point in the history
  • Loading branch information
zihaomu committed Apr 7, 2022
1 parent 9aa6470 commit e36948c
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 7 deletions.
20 changes: 20 additions & 0 deletions modules/dnn/include/opencv2/dnn/all_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,26 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<ActivationLayerInt8> create(const LayerParams &params);
};

class CV_EXPORTS SignLayer : public ActivationLayer
{
public:
static Ptr<SignLayer> create(const LayerParams &params);
};

class CV_EXPORTS ShrinkLayer : public ActivationLayer
{
public:
float bias;
float lambd;
static Ptr<ShrinkLayer> create(const LayerParams &params);
};

class CV_EXPORTS ReciprocalLayer : public ActivationLayer
{
public:
static Ptr<ReciprocalLayer> create(const LayerParams &params);
};

/* Layers used in semantic segmentation */

class CV_EXPORTS CropLayer : public Layer
Expand Down
21 changes: 21 additions & 0 deletions modules/dnn/src/cuda/activations.cu
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,21 @@ void selu(const Stream& stream, Span<T> output, View<T> input, T alpha, T gamma)
generic_op<T, SeluFunctor<T>>(stream, output, input, {alpha, gamma});
}

template <class T>
void sign(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, SignFunctor<T>>(stream, output, input);
}

template <class T>
void shrink(const Stream& stream, Span<T> output, View<T> input, T bias, T lambd) {
generic_op<T, ShrinkFunctor<T>>(stream, output, input, {bias, lambd});
}

template <class T>
void reciprocal(const Stream& stream, Span<T> output, View<T> input) {
generic_op<T, SignFunctor<T>>(stream, output, input);
}

template <class T>
void thresholdedrelu(const Stream& stream, Span<T> output, View<T> input, T alpha) {
generic_op<T, ThresholdedReluFunctor<T>>(stream, output, input, {alpha});
Expand Down Expand Up @@ -312,6 +327,9 @@ template void selu<__half>(const Stream&, Span<__half>, View<__half>, __half, __
template void thresholdedrelu<__half>(const Stream&, Span<__half>, View<__half>, __half);
template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
template void exp<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
template void sign<__half>(const Stream&, Span<__half>, View<__half>);
template void shrink<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
template void reciprocal<__half>(const Stream&, Span<__half>, View<__half>);
#endif


Expand Down Expand Up @@ -351,6 +369,9 @@ template void selu<float>(const Stream&, Span<float>, View<float>, float, float)
template void thresholdedrelu<float>(const Stream&, Span<float>, View<float>, float);
template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
template void exp<float>(const Stream&, Span<float>, View<float>, float, float);
template void sign<float>(const Stream&, Span<float>, View<float>);
template void shrink<float>(const Stream&, Span<float>, View<float>, float, float);
template void reciprocal<float>(const Stream&, Span<float>, View<float>);

template <class T, std::size_t N> static
void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
Expand Down
44 changes: 44 additions & 0 deletions modules/dnn/src/cuda/functors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,50 @@ struct DivFunctor {
CUDA4DNN_DEVICE T operator()(T x, T y) { return x / y; }
};

template <class T>
struct SignFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() {}
};

CUDA4DNN_DEVICE SignFunctor() : SignFunctor(Params{}) { }

CUDA4DNN_DEVICE T operator()(T value) {
return value > T(0) ? T(1) : (value < T(0) ? T(-1) : T(0));
}
};

template <class T>
struct ShrinkFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() : bias(0), lambd(0.5) { }
CUDA4DNN_HOST_DEVICE Params(T bias_, T lambd_) : bias(bias_), lambd(lambd_) { }
T bias, lambd;
};

CUDA4DNN_DEVICE ShrinkFunctor() : bias(0), lambd(0.5) { }
CUDA4DNN_DEVICE ShrinkFunctor(const Params& params) : bias{params.bias}, lambd{params.lambd} { }

CUDA4DNN_DEVICE T operator()(T value) {
return value > lambd ? value - bias : (value < -lambd ? value + bias : T(0));
}

T bias, lambd;
};

template <class T>
struct ReciprocalFunctor {
struct Params {
CUDA4DNN_HOST_DEVICE Params() {}
};

CUDA4DNN_DEVICE ReciprocalFunctor() : ReciprocalFunctor(Params{}) { }

CUDA4DNN_DEVICE T operator()(T value) {
return T(1.0f)/value;
}
};

}}}} /* namespace cv::dnn::cuda4dnn::kernels */

#endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */
8 changes: 8 additions & 0 deletions modules/dnn/src/cuda4dnn/kernels/activations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
void exp(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T normScale, T normShift);

template <class T>
void sign(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);

template <class T>
void shrink(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T bias, T lambd);

template <class T>
void reciprocal(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATIONS_HPP */
46 changes: 46 additions & 0 deletions modules/dnn/src/cuda4dnn/primitives/activation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,52 @@ namespace cv { namespace dnn { namespace cuda4dnn {
const T normScale, normShift;
};

template <class T>
class ShrinkOp final : public BaseOp<ShrinkOp, T> {
public:
ShrinkOp(csl::Stream stream_, T bias_, T lambd_)
: stream(std::move(stream_)), bias{ bias_ }, lambd{ lambd_ } { }

void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::shrink<T>(stream, output, input, bias, lambd);
}

private:
csl::Stream stream;
const T bias, lambd;
};

template <class T>
class SignOp final : public BaseOp<SignOp, T> {
public:
SignOp(csl::Stream stream_)
: stream(std::move(stream_)) { }

void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::sign<T>(stream, output, input);
}

private:
csl::Stream stream;
};

template <class T>
class ReciprocalOp final : public BaseOp<ReciprocalOp, T> {
public:
ReciprocalOp(csl::Stream stream_)
: stream(std::move(stream_)) { }

void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
{
kernels::reciprocal<T>(stream, output, input);
}

private:
csl::Stream stream;
};

}}} /* namespace cv::dnn::cuda4dnn */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ACTIVATION_HPP */
3 changes: 3 additions & 0 deletions modules/dnn/src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(HardSwish, HardSwishLayer);
CV_DNN_REGISTER_LAYER_CLASS(Sin, SinLayer);
CV_DNN_REGISTER_LAYER_CLASS(Sinh, SinhLayer);
CV_DNN_REGISTER_LAYER_CLASS(Sign, SignLayer);
CV_DNN_REGISTER_LAYER_CLASS(Shrink, ShrinkLayer);
CV_DNN_REGISTER_LAYER_CLASS(Softplus, SoftplusLayer);
CV_DNN_REGISTER_LAYER_CLASS(Softsign, SoftsignLayer);
CV_DNN_REGISTER_LAYER_CLASS(Tan, TanLayer);
Expand All @@ -144,6 +146,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Silence, BlankLayer);
CV_DNN_REGISTER_LAYER_CLASS(Const, ConstLayer);
CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);

CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
Expand Down
117 changes: 117 additions & 0 deletions modules/dnn/src/layers/elementwise_layers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,96 @@ struct ChannelsPReLUFunctor : public BaseFunctor
int64 getFLOPSPerElement() const { return 1; }
};

struct SignFunctor : public BaseDefaultFunctor<SignFunctor>
{
typedef SignLayer Layer;

bool supportBackend(int backendId, int)
{
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_CUDA;
}

inline float calculate(float x) const
{
return x > 0 ? 1 : (x < 0 ? -1 : 0);
}

#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
{
return make_cuda_node<cuda4dnn::SignOp>(target, stream);
}
#endif

int64 getFLOPSPerElement() const { return 1; }
};

template<>
const char* const SignFunctor::BaseDefaultFunctor<SignFunctor>::ocl_kernel_name = "SignForward";


struct ShrinkFunctor : public BaseDefaultFunctor<ShrinkFunctor>
{
typedef ShrinkLayer Layer;
float bias;
float lambd;

explicit ShrinkFunctor(float bias_ = 0.0f, float lambd_ = 0.5f) : bias(bias_), lambd(lambd_) {}

bool supportBackend(int backendId, int)
{
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_CUDA;
}

inline float calculate(float x) const
{
return x > lambd ? x - bias : (x < -lambd ? x + bias : 0);
}

#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
{
return make_cuda_node<cuda4dnn::ShrinkOp>(target, stream);
}
#endif

int64 getFLOPSPerElement() const { return 1; }
};

template<>
const char* const ShrinkFunctor::BaseDefaultFunctor<ShrinkFunctor>::ocl_kernel_name = "ShrinkForward";

struct ReciprocalFunctor : public BaseDefaultFunctor<ReciprocalFunctor>
{
typedef ReciprocalLayer Layer;

bool supportBackend(int backendId, int)
{
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_CUDA;
}

inline float calculate(float x) const
{
return 1.0/x;
}

#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
{
return make_cuda_node<cuda4dnn::ReciprocalOp>(target, stream);
}
#endif

int64 getFLOPSPerElement() const { return 1; }
};

template<>
const char* const ReciprocalFunctor::BaseDefaultFunctor<ReciprocalFunctor>::ocl_kernel_name = "ReciprocalForward";


#define ACTIVATION_CREATOR_FOR(_Layer, _Functor, ...) \
Ptr<_Layer> _Layer::create() { \
return return Ptr<_Layer>( new ElementWiseLayer<_Functor>(_Functor()) ); }
Expand Down Expand Up @@ -2611,5 +2701,32 @@ Ptr<Layer> ChannelsPReLULayer::create(const LayerParams& params)
return l;
}

Ptr<SignLayer> SignLayer::create(const LayerParams& params)
{
Ptr<SignLayer> l(new ElementWiseLayer<SignFunctor>());
l->setParamsFrom(params);

return l;
}

Ptr<ReciprocalLayer> ReciprocalLayer::create(const LayerParams& params)
{
Ptr<ReciprocalLayer> l(new ElementWiseLayer<ReciprocalFunctor>());
l->setParamsFrom(params);

return l;
}

Ptr<ShrinkLayer> ShrinkLayer::create(const LayerParams& params)
{
float bias = params.get<float>("bias", 0.f);
float lambd = params.get<float>("lambd", 0.5f);
Ptr<ShrinkLayer> l(new ElementWiseLayer<ShrinkFunctor>(ShrinkFunctor(bias, lambd)));
l->setParamsFrom(params);
l->bias = bias;
l->lambd = lambd;

return l;
}
}
}
4 changes: 2 additions & 2 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3675,8 +3675,8 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)

std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
"Cosh", "Dropout", "Erf", "Exp", "Floor", "HardSigmoid", "HardSwish",
"Identity", "Log", "Round", "Selu", "Sigmoid", "Sin", "Sinh", "Softmax",
"Softplus", "Softsign", "Sqrt", "Tan", "ThresholdedRelu"};
"Identity", "Log", "Round", "Reciprocal", "Selu", "Sign", "Sigmoid", "Sin", "Sinh", "Softmax",
"Softplus", "Softsign", "Shrink", "Sqrt", "Tan", "ThresholdedRelu"};
for (const auto& name : simpleLayers)
{
dispatch[name] = &ONNXImporter::parseSimpleLayers;
Expand Down
23 changes: 23 additions & 0 deletions modules/dnn/src/opencl/activations.cl
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,26 @@ __kernel void ThresholdedReluForward(const int n, __global T* in, __global T* ou
if(index < n)
out[index] = (in[index] > alpha ? in[index] : 0.f);
}

__kernel void ShrinkForward(const int n, __global T* in, __global T* out,
const KERNEL_ARG_DTYPE bias,
const KERNEL_ARG_DTYPE lambd)
{
int index = get_global_id(0);
if(index < n)
out[index] = in[index] < -lambd ? in[index] + bias : (in[index] > lambd ? in[index] - bias : 0.f);
}

__kernel void SignForward(const int n, __global T* in, __global T* out)
{
int index = get_global_id(0);
if(index < n)
out[index] = in[index] > 0.f ? 1.0f : (in[index] < 0.f) ? -1.0f : 0.0f);
}

__kernel void ReciprocalForward(const int n, __global T* in, __global T* out)
{
int index = get_global_id(0);
if(index < n)
out[index] = 1.0f/in[index];
}
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,6 @@
"test_range_float_type_positive_delta_expanded",
"test_range_int32_type_negative_delta",
"test_range_int32_type_negative_delta_expanded",
"test_reciprocal",
"test_reciprocal_example",
"test_reduce_sum_default_axes_keepdims_example",
"test_reduce_sum_default_axes_keepdims_random",
"test_reduce_sum_do_not_keepdims_example",
Expand Down Expand Up @@ -479,9 +477,6 @@
"test_shape_start_1_end_2",
"test_shape_start_1_end_negative_1",
"test_shape_start_negative_1",
"test_shrink_hard",
"test_shrink_soft",
"test_sign",
"test_simple_rnn_batchwise",
"test_simple_rnn_defaults",
"test_simple_rnn_with_initial_bias",
Expand Down

0 comments on commit e36948c

Please sign in to comment.