Skip to content

Commit

Permalink
common: introduce tanh eltwise operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Fomenko, Evarist M committed Jun 18, 2017
1 parent 8044c61 commit 11918d7
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 46 deletions.
8 changes: 4 additions & 4 deletions include/mkldnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,17 +409,17 @@ mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(

/** Initializes a @p eltwise_desc for forward propagation using @p prop_kind
* (possible values are #mkldnn_forward_training or #mkldnn_forward_inference),
* @p alg_kind algorithm (possible values: #mkldnn_eltwise_relu), memory
* descriptor @p data_desc, and @p alpha, @p beta parameters.
* @p alg_kind algorithm, memory descriptor @p data_desc, and @p alpha,
* @p beta parameters.
* @sa mkldnn_eltwise_desc_t for details */
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(
mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind,
mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
double alpha, double beta);

/** Initializes a @p eltwise_desc for backward propagation using @p alg_kind
* algorithm (possible values: #mkldnn_eltwise_relu), memory descriptors
* @p diff_data_desc and @p data_desc, and @p alpha, @p beta parameters.
 * algorithm, memory descriptors @p diff_data_desc and @p data_desc, and
 * @p alpha, @p beta parameters.
* @sa mkldnn_eltwise_desc_t for details */
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(
mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind,
Expand Down
1 change: 1 addition & 0 deletions include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ enum algorithm {
convolution_direct = c_api::mkldnn_convolution_direct,
convolution_winograd = c_api::mkldnn_convolution_winograd,
eltwise_relu = c_api::mkldnn_eltwise_relu,
eltwise_tanh = c_api::mkldnn_eltwise_tanh,
lrn_across_channels = c_api::mkldnn_lrn_across_channels,
lrn_within_channel = c_api::mkldnn_lrn_within_channel,
pooling_max = c_api::mkldnn_pooling_max,
Expand Down
6 changes: 5 additions & 1 deletion include/mkldnn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ typedef enum {
mkldnn_convolution_winograd = 2,
/** Eltwise: ReLU */
mkldnn_eltwise_relu = 8,
/** Eltwise: hyperbolic tangent non-linearity (tanh) */
mkldnn_eltwise_tanh = 9,
/** Max pooling */
mkldnn_pooling_max = 34,
/** Average pooling include padding */
Expand Down Expand Up @@ -446,7 +448,8 @@ typedef struct {
* #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
*/
mkldnn_prop_kind_t prop_kind;
/** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu */
/** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
* #mkldnn_eltwise_tanh */
mkldnn_alg_kind_t alg_kind;
/** Source and destination memory descriptor. */
mkldnn_memory_desc_t data_desc;
Expand All @@ -455,6 +458,7 @@ typedef struct {
/** Algorithm specific parameter.
* Accordance table:
* - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
* - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
*/
double alpha, beta;
/** Scaling factor for negative values. Stored as double-precision, but
Expand Down
1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ namespace alg_kind {
const alg_kind_t convolution_direct = mkldnn_convolution_direct;
const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
const alg_kind_t pooling_max = mkldnn_pooling_max;
const alg_kind_t pooling_avg = mkldnn_pooling_avg;
const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
Expand Down
2 changes: 1 addition & 1 deletion src/common/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
&& !any_null(eltwise_desc, data_desc)
&& one_of(prop_kind, forward_training, forward_inference,
backward_data)
&& one_of(alg_kind, eltwise_relu)
&& one_of(alg_kind, eltwise_relu, eltwise_tanh)
&& implication(prop_kind == backward_data, diff_data_desc != nullptr);
if (!args_ok) return invalid_arguments;

Expand Down
48 changes: 45 additions & 3 deletions src/cpu/ref_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,26 @@ namespace mkldnn {
namespace impl {
namespace cpu {

using namespace alg_kind;

namespace {
/* Reference scalar kernels for the supported eltwise algorithms.
 * T is the data type, A the type of the alpha parameter. */

/* ReLU forward: identity for positive inputs, alpha-scaled otherwise. */
template <typename T, typename A> T relu_fwd(T s, A alpha) {
    return s > 0 ? s : s * alpha;
}
/* ReLU backward: pass the gradient through for positive inputs,
 * scale it by alpha otherwise. */
template <typename T, typename A> T relu_bwd(T dd, T s, A alpha) {
    return s > 0 ? dd : dd * alpha;
}

/* tanh forward: tanh(s) = (1 - e) / (1 + e) with e = exp(-2*|s|).
 * The negative exponent keeps e in (0, 1], so the expression never
 * evaluates inf/inf: the previous exp(+2*s) form returned NaN once
 * 2*s overflowed expf (|s| greater than ~44 in single precision). */
template <typename T> T tanh_fwd(T s) {
    const T e = ::expf(-2 * (s > 0 ? s : -s)); /* e in (0, 1] */
    const T a = (1 - e) / (1 + e);             /* tanh(|s|) */
    return s > 0 ? a : -a;
}
/* tanh backward: d/ds tanh(s) = 1 - tanh(s)^2 */
template <typename T> T tanh_bwd(T dd, T s) {
    const T th = tanh_fwd(s);
    return dd * (1 - th * th);
}
}

template <impl::data_type_t data_type>
void ref_eltwise_fwd_t<data_type>::execute_forward_generic() {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
Expand All @@ -37,6 +57,7 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_generic() {
const int C = conf_.C();
const int H = conf_.H();
const int W = conf_.W();
const auto alg_kind = conf_.desc()->alg_kind;
const double alpha = conf_.desc()->alpha;

# pragma omp parallel for collapse(4) schedule(static)
Expand All @@ -47,7 +68,11 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_generic() {
auto d_off = data_d.off(n, c, h, w);
data_t s = src[d_off];
data_t &d = dst[d_off];
d = (s > 0) ? s : s * alpha;
switch (alg_kind) {
case eltwise_relu: d = relu_fwd(s, alpha); break;
case eltwise_tanh: d = tanh_fwd(s); break;
default: assert(!"unknown eltwise alg_kind");
}
}
}
}
Expand All @@ -62,14 +87,19 @@ void ref_eltwise_fwd_t<data_type>::execute_forward_dense() {
const memory_desc_wrapper data_d(conf_.src_pd());

const size_t nelems = data_d.nelems();
const auto alg_kind = conf_.desc()->alg_kind;
const double alpha = conf_.desc()->alpha;

src += data_d.blocking_desc().offset_padding;
dst += data_d.blocking_desc().offset_padding;

# pragma omp parallel for schedule(static)
for (int e = 0; e < nelems; ++e) {
dst[e] = src[e] * ((src[e] > 0) ? 1. : alpha);
switch (alg_kind) {
case eltwise_relu: dst[e] = relu_fwd(src[e], alpha); break;
case eltwise_tanh: dst[e] = tanh_fwd(src[e]); break;
default: assert(!"unknown eltwise alg_kind");
}
}
}

Expand All @@ -86,6 +116,7 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_generic() {
const int C = conf_.C();
const int H = conf_.H();
const int W = conf_.W();
const auto alg_kind = conf_.desc()->alg_kind;
const double alpha = conf_.desc()->alpha;

# pragma omp parallel for collapse(4) schedule(static)
Expand All @@ -98,7 +129,11 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_generic() {
data_t s = src[data_off];
data_t dd = diff_dst[diff_data_off];
data_t &ds = diff_src[diff_data_off];
ds = dd * ((s > 0) ? 1. : alpha);
switch (alg_kind) {
case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
case eltwise_tanh: ds = tanh_bwd(dd, s); break;
default: assert(!"unknown eltwise alg_kind");
}
}
}
}
Expand All @@ -115,6 +150,7 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_dense() {
const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());

const size_t nelems = data_d.nelems();
const auto alg_kind = conf_.desc()->alg_kind;
const double alpha = conf_.desc()->alpha;

src += data_d.blocking_desc().offset_padding;
Expand All @@ -124,6 +160,12 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_dense() {
# pragma omp parallel for schedule(static)
for (int e = 0; e < nelems; ++e) {
diff_src[e] = diff_dst[e] * ((src[e] > 0) ? 1. : alpha);
switch (alg_kind) {
case eltwise_relu: diff_src[e] = relu_bwd(diff_dst[e], src[e], alpha);
break;
case eltwise_tanh: diff_src[e] = tanh_bwd(diff_dst[e], src[e]); break;
default: assert(!"unknown eltwise alg_kind");
}
}
}

Expand Down
106 changes: 69 additions & 37 deletions tests/gtests/test_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@

namespace mkldnn {

/* Reference ReLU forward: identity above zero, slope alpha below. */
template <typename T> T relu_fwd(T s, T alpha) {
    if (s > 0) return s;
    return alpha * s;
}
/* Reference ReLU backward: the gradient passes through above zero
 * and is scaled by alpha below. */
template <typename T> T relu_bwd(T dd, T s, T alpha) {
    if (s > 0) return dd;
    return alpha * dd;
}

/* Reference tanh forward: tanh(s) = (1 - e) / (1 + e) with
 * e = exp(-2*|s|). The negative exponent keeps e in (0, 1], so the
 * expression never evaluates inf/inf: the exp(+2*s) form returns NaN
 * once 2*s overflows expf (|s| greater than ~44 in single precision). */
template <typename T> T tanh_fwd(T s) {
    const T e = ::expf(-2 * (s > 0 ? s : -s)); /* e in (0, 1] */
    const T a = (1 - e) / (1 + e);             /* tanh(|s|) */
    return s > 0 ? a : -a;
}
/* Reference tanh backward: d/ds tanh(s) = 1 - tanh(s)^2 */
template <typename T> T tanh_bwd(T dd, T s) {
    const T th = tanh_fwd(s);
    return dd * (1 - th * th);
}

template <typename data_t>
struct eltwise_test_params {
engine::kind engine_kind;
Expand All @@ -32,8 +48,8 @@ struct eltwise_test_params {
};

template <typename data_t>
void check_eltwise_fwd(data_t alpha, const memory::desc &md,
const memory &src, const memory &dst)
void check_eltwise_fwd(const eltwise_test_params<data_t> &p,
const memory::desc &md, const memory &src, const memory &dst)
{
data_t *src_data = (data_t *)src.get_data_handle();
data_t *dst_data = (data_t *)dst.get_data_handle();
Expand All @@ -47,13 +63,20 @@ void check_eltwise_fwd(data_t alpha, const memory::desc &md,
size_t W = md.data.dims[3];
for (size_t i = 0; i < N * C * H * W; ++i) {
data_t s = src_data[i];
EXPECT_NEAR(dst_data[i], s > 0 ? s : s * alpha, 1.e-7);
data_t ref_d = 0;
switch (p.alg_kind) {
case eltwise_relu: ref_d = relu_fwd(s, p.alpha); break;
case eltwise_tanh: ref_d = tanh_fwd(s); break;
default: assert(!"unknown alg_kind");
}
EXPECT_NEAR(dst_data[i], ref_d, 1.e-6);
}
}

template <typename data_t>
void check_eltwise_bwd(data_t alpha, const memory::desc &md,
const memory &src, const memory &diff_dst, const memory &diff_src)
void check_eltwise_bwd(const eltwise_test_params<data_t> &p,
const memory::desc &md, const memory &src, const memory &diff_dst,
const memory &diff_src)
{
data_t *src_data = (data_t *)src.get_data_handle();
data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
Expand All @@ -72,8 +95,13 @@ void check_eltwise_bwd(data_t alpha, const memory::desc &md,
for (size_t i = 0; i < N * C * H * W; ++i) {
data_t ref_s = src_data[map_index(data_d, i)];
data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
data_t ref_ds = ref_dd * ((ref_s > 0) ? 1. : alpha);
EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-7);
data_t ref_ds = 0;
switch (p.alg_kind) {
case eltwise_relu: ref_ds = relu_bwd(ref_dd, ref_s, p.alpha); break;
case eltwise_tanh: ref_ds = tanh_bwd(ref_dd, ref_s); break;
default: assert(!"unknown alg_kind");
}
EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-6);
}
}

Expand Down Expand Up @@ -133,7 +161,7 @@ class eltwise_test : public ::testing::TestWithParam<eltwise_test_params<data_t>
auto s = stream(stream::kind::lazy);
s.submit(pipeline).wait();

check_eltwise_fwd(p.alpha, *data_desc, *src, *dst);
check_eltwise_fwd(p, *data_desc, *src, *dst);
}

void Backward() {
Expand All @@ -155,8 +183,7 @@ class eltwise_test : public ::testing::TestWithParam<eltwise_test_params<data_t>
auto s = stream(stream::kind::lazy);
s.submit(pipeline).wait();

check_eltwise_bwd(p.alpha, *data_desc, *src, *diff_dst,
*diff_src);
check_eltwise_bwd(p, *data_desc, *src, *diff_dst, *diff_src);
}
};

Expand All @@ -176,46 +203,51 @@ TEST_P(eltwise_test_float, TestsEltwise)
EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
alpha, beta, {mb, c, h, w} }

#define PARAMS_ALL_ALG(...) \
PARAMS(eltwise_relu, __VA_ARGS__), \
PARAMS(eltwise_tanh, __VA_ARGS__) \


#define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
str, eltwise_test_float, ::testing::Values(__VA_ARGS__))

INST_TEST_CASE(SimpleZeroNegativeSlope_NCHW,
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 16, 4, 4),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 16, 8, 8),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 16, 16, 8),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 16, 10, 8),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 3, 5, 7, 11)
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 4, 4),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 8, 8),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 16, 8),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 10, 8),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11)
);

INST_TEST_CASE(Simple_NCHW,
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11)
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11)
);

INST_TEST_CASE(Simple,
PARAMS(eltwise_relu, nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
PARAMS(eltwise_relu, nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
PARAMS(eltwise_relu, nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
PARAMS(eltwise_relu, nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
PARAMS(eltwise_relu, nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
PARAMS(eltwise_relu, nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
);

INST_TEST_CASE(AlexNet_NCHW,
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
PARAMS(eltwise_relu, nchw, nchw, 0.f, 0.f, 2, 384, 13, 13)
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13)
);

}

0 comments on commit 11918d7

Please sign in to comment.