Skip to content

Commit

Permalink
cpu: decouple ref and jit lnorm kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
Fomenko, Evarist M committed Apr 22, 2020
1 parent b94a27a commit 9dfa9fd
Show file tree
Hide file tree
Showing 5 changed files with 403 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,71 +14,50 @@
* limitations under the License.
*******************************************************************************/

#ifndef CPU_JIT_UNI_LAYER_NORMALIZATION_KERNELS_HPP
#define CPU_JIT_UNI_LAYER_NORMALIZATION_KERNELS_HPP

#include "cpu_layer_normalization_pd.hpp"
#include "jit_generator.hpp"

#include "jit_simple_layer_normalization_kernels.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace lnorm_utils {

class statistics_kernel_t : jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(
jit_uni_layer_normalization_fwd_t::statistics_kernel);
statistics_kernel_t(const layer_normalization_pd_t *pd)
: C_(pd->norm_axis()), ker_(nullptr) {
if (mayiuse(avx2)) { generate(); }
}
~statistics_kernel_t() {}

void operator()(const float *src, float *mean, float *var) {
if (ker_) {
ker_args args;
args.src = src;
args.mean = mean;
args.var = var;
ker_(&args);
} else {
float v_mean = 0;
PRAGMA_OMP_SIMD(reduction(+ : v_mean))
for (dim_t c = 0; c < C_; ++c) {
v_mean += src[c];
}
v_mean /= C_;
struct jit_statistics_kernel_t : statistics_kernel_t, jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(lnorm_utils::jit_statistics_kernel_t);

float v_variance = 0;
PRAGMA_OMP_SIMD(reduction(+ : v_variance))
for (dim_t c = 0; c < C_; ++c) {
auto m = src[c] - v_mean;
v_variance += m * m;
}
v_variance /= C_;
jit_statistics_kernel_t(const layer_normalization_pd_t *pd)
: statistics_kernel_t(pd) {
assert(mayiuse(avx2));
generate();
}

*mean = v_mean;
*var = v_variance;
}
virtual void operator()(
const float *src, float *mean, float *var) const override {
assert(ker_);
ker_args_t args;
args.src = src;
args.mean = mean;
args.var = var;
ker_(&args);
}

private:
int C_;
int unroll_factor_ = 8;
int simd_w_ = 8;

struct ker_args {
struct ker_args_t {
const float *src;
float *mean;
float *var;
};
void (*ker_)(const ker_args *args);
void (*ker_)(const ker_args_t *args) = nullptr;

void generate() {
using namespace Xbyak;

preamble();
#define PARAM_OFF(x) offsetof(ker_args, x)
#define PARAM_OFF(x) offsetof(ker_args_t, x)
mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
Expand Down Expand Up @@ -178,54 +157,38 @@ class statistics_kernel_t : jit_generator {
Xbyak::Ymm ymm_mean = Xbyak::Ymm(15);
};

class data_kernel_t : jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(
jit_uni_layer_normalization_fwd_t::data_kernel);
data_kernel_t(const layer_normalization_pd_t *pd)
: C_(pd->norm_axis())
, use_scaleshift_(pd->use_scaleshift())
, eps_(pd->desc()->layer_norm_epsilon)
, ker_(nullptr) {
if (mayiuse(avx2)) { generate(); }
struct jit_data_kernel_t : data_kernel_t, jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(lnorm_utils::jit_data_kernel_t);

jit_data_kernel_t(const layer_normalization_pd_t *pd) : data_kernel_t(pd) {
assert(mayiuse(avx2));
generate();
}
~data_kernel_t() {}
void operator()(const float *src, float *dst, const float *ss,
const float *mean, const float *var) {
if (ker_) {
ker_args args;
args.src = src;
args.dst = dst;
args.ss = ss;
args.mean = mean;
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
} else {
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
PRAGMA_OMP_SIMD()
for (dim_t c = 0; c < C_; ++c) {
const float sm = (use_scaleshift_ ? ss[c] : 1.0f) * inv_sqrtvar;
const float sv = use_scaleshift_ ? ss[C_ + c] : 0;
dst[c] = sm * (src[c] - *mean) + sv;
}
}

virtual void operator()(const float *src, float *dst, const float *ss,
const float *mean, const float *var) const override {
assert(ker_);
ker_args_t args;
args.src = src;
args.dst = dst;
args.ss = ss;
args.mean = mean;
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
}

private:
int C_;
bool use_scaleshift_;
const float eps_;
int simd_w_ = 8;

struct ker_args {
struct ker_args_t {
const float *src;
float *dst;
const float *ss;
const float *mean;
const float *inv_sqrtvar;
};
void (*ker_)(const ker_args *args);
void (*ker_)(const ker_args_t *args) = nullptr;

void load(Xbyak::Ymm &ymm_src, Xbyak::Reg64 reg_src, int nelems,
size_t offt) {
Expand All @@ -251,7 +214,7 @@ class data_kernel_t : jit_generator {
using namespace Xbyak;

preamble();
#define PARAM_OFF(x) offsetof(ker_args, x)
#define PARAM_OFF(x) offsetof(ker_args_t, x)
mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
mov(reg_ss, ptr[reg_param + PARAM_OFF(ss)]);
Expand Down Expand Up @@ -305,54 +268,42 @@ class data_kernel_t : jit_generator {
Xbyak::Ymm ymm_mean = Xbyak::Ymm(15);
};

class diff_ss_kernel_t : jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(
jit_uni_layer_normalization_fwd_t::diff_dst_kernel);
diff_ss_kernel_t(const layer_normalization_pd_t *pd)
: C_(pd->norm_axis())
, eps_(pd->desc()->layer_norm_epsilon)
, ker_(nullptr) {
if (mayiuse(avx2)) { generate(); }
struct jit_diff_ss_kernel_t : diff_ss_kernel_t, jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(lnorm_utils::jit_diff_ss_kernel_t);

jit_diff_ss_kernel_t(const layer_normalization_pd_t *pd)
: diff_ss_kernel_t(pd) {
assert(mayiuse(avx2));
generate();
}
~diff_ss_kernel_t() {}
void operator()(const float *src, const float *diff_dst, float *diff_gamma,
float *diff_beta, const float *mean, const float *var) {
if (ker_) {
ker_args args;
args.src = src;
args.diff_dst = diff_dst;
args.diff_gamma = diff_gamma;
args.diff_beta = diff_beta;
args.mean = mean;
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
} else {
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
PRAGMA_OMP_SIMD()
for (dim_t c = 0; c < C_; c++) {
float dd = diff_dst[c];
diff_gamma[c] += (src[c] - *mean) * dd * inv_sqrtvar;
diff_beta[c] += dd;
}
}

virtual void operator()(const float *src, const float *diff_dst,
float *diff_gamma, float *diff_beta, const float *mean,
const float *var) const override {
assert(ker_);
ker_args_t args;
args.src = src;
args.diff_dst = diff_dst;
args.diff_gamma = diff_gamma;
args.diff_beta = diff_beta;
args.mean = mean;
float inv_sqrtvar = 1. / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
}

private:
int C_;
const float eps_;
int simd_w_ = 8;

struct ker_args {
struct ker_args_t {
const float *src;
const float *diff_dst;
float *diff_gamma;
float *diff_beta;
const float *mean;
const float *inv_sqrtvar;
};
void (*ker_)(const ker_args *args);
void (*ker_)(const ker_args_t *args) = nullptr;

void load(Xbyak::Ymm &ymm_src, Xbyak::Reg64 reg_src, int nelems,
size_t offt) {
Expand All @@ -378,7 +329,7 @@ class diff_ss_kernel_t : jit_generator {
using namespace Xbyak;

preamble();
#define PARAM_OFF(x) offsetof(ker_args, x)
#define PARAM_OFF(x) offsetof(ker_args_t, x)
mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
mov(reg_diff_gamma, ptr[reg_param + PARAM_OFF(diff_gamma)]);
Expand Down Expand Up @@ -435,72 +386,42 @@ class diff_ss_kernel_t : jit_generator {
Xbyak::Ymm ymm_mean = Xbyak::Ymm(15);
};

class diff_data_kernel_t : jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(
jit_uni_layer_normalization_fwd_t::diff_data_kernel);
diff_data_kernel_t(const layer_normalization_pd_t *pd)
: C_(pd->norm_axis())
, eps_(pd->desc()->layer_norm_epsilon)
, calculate_diff_stats_(!pd->use_global_stats())
, use_scaleshift_(pd->use_scaleshift())
, ker_(nullptr) {
if (mayiuse(avx2)) { generate(); }
struct jit_diff_data_kernel_t : diff_data_kernel_t, jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(lnorm_utils::jit_diff_data_kernel_t);

jit_diff_data_kernel_t(const layer_normalization_pd_t *pd)
: diff_data_kernel_t(pd) {
assert(mayiuse(avx2));
generate();
}
~diff_data_kernel_t() {}
void operator()(const float *src, const float *diff_dst, float *diff_src,
const float *ss, const float *mean, const float *var) {
if (ker_) {
ker_args args;
args.src = src;
args.diff_dst = diff_dst;
args.diff_src = diff_src;
args.ss = ss;
args.mean = mean;
float inv_sqrtvar = 1.f / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
} else {
float inv_sqrtvar = 1.f / sqrtf(*var + eps_);
float dd_gamma = 0, dd_gamma_x = 0;
if (calculate_diff_stats_) {
PRAGMA_OMP_SIMD(reduction(+ : dd_gamma, dd_gamma_x))
for (dim_t c = 0; c < C_; c++) {
float gamma = use_scaleshift_ ? ss[c] : 1;
dd_gamma += diff_dst[c] * gamma;
dd_gamma_x += diff_dst[c] * gamma * (src[c] - *mean);
}
dd_gamma_x *= inv_sqrtvar;
}
PRAGMA_OMP_SIMD()
for (dim_t c = 0; c < C_; c++) {
float gamma = use_scaleshift_ ? ss[c] : 1;
float v_diff_src = diff_dst[c] * gamma;
if (calculate_diff_stats_)
v_diff_src -= dd_gamma / C_
+ (src[c] - *mean) * dd_gamma_x * inv_sqrtvar / C_;
v_diff_src *= inv_sqrtvar;
diff_src[c] = v_diff_src;
}
}

virtual void operator()(const float *src, const float *diff_dst,
float *diff_src, const float *ss, const float *mean,
const float *var) const override {
assert(ker_);
ker_args_t args;
args.src = src;
args.diff_dst = diff_dst;
args.diff_src = diff_src;
args.ss = ss;
args.mean = mean;
float inv_sqrtvar = 1.f / sqrtf(*var + eps_);
args.inv_sqrtvar = &inv_sqrtvar;
ker_(&args);
}

private:
int C_;
const float eps_;
bool calculate_diff_stats_;
bool use_scaleshift_;
int simd_w_ = 8;

struct ker_args {
struct ker_args_t {
const float *src;
const float *diff_dst;
float *diff_src;
const float *ss;
const float *mean;
const float *inv_sqrtvar;
};
void (*ker_)(const ker_args *args);
void (*ker_)(const ker_args_t *args) = nullptr;

void load(Xbyak::Ymm &ymm_src, Xbyak::Reg64 reg_src, int nelems,
size_t offt) {
Expand All @@ -526,7 +447,7 @@ class diff_data_kernel_t : jit_generator {
using namespace Xbyak;

preamble();
#define PARAM_OFF(x) offsetof(ker_args, x)
#define PARAM_OFF(x) offsetof(ker_args_t, x)
mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
mov(reg_diff_src, ptr[reg_param + PARAM_OFF(diff_src)]);
Expand Down Expand Up @@ -640,10 +561,26 @@ class diff_data_kernel_t : jit_generator {
Xbyak::Ymm ymm_mean = Xbyak::Ymm(15);
};

statistics_kernel_t *jit_statistics_kernel_create(
const layer_normalization_pd_t *pd) {
return mayiuse(avx2) ? new jit_statistics_kernel_t(pd) : nullptr;
}

data_kernel_t *jit_data_kernel_create(const layer_normalization_pd_t *pd) {
return mayiuse(avx2) ? new jit_data_kernel_t(pd) : nullptr;
}

diff_ss_kernel_t *jit_diff_ss_kernel_create(
const layer_normalization_pd_t *pd) {
return mayiuse(avx2) ? new jit_diff_ss_kernel_t(pd) : nullptr;
}

diff_data_kernel_t *jit_diff_data_kernel_create(
const layer_normalization_pd_t *pd) {
return mayiuse(avx2) ? new jit_diff_data_kernel_t(pd) : nullptr;
}

} // namespace lnorm_utils
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
Loading

0 comments on commit 9dfa9fd

Please sign in to comment.