Commit 67ceefd
gpu: intel: ocl: rnn: rename ref_ kernels with simple_ prefix for RNN
h-sadia committed Jul 23, 2024
1 parent 1932716 · commit 67ceefd
Showing 9 changed files with 77 additions and 77 deletions.
4 changes: 2 additions & 2 deletions src/gpu/gpu_rnn_list.cpp
@@ -31,11 +31,11 @@ using namespace dnnl::impl::prop_kind;
const std::map<pk_impl_key_t, std::vector<impl_list_item_t>>
impl_list_map REG_RNN_P({
{{forward}, {
-GPU_INSTANCE_INTEL(intel::ocl::ref_rnn_fwd_t)
+GPU_INSTANCE_INTEL(intel::ocl::simple_rnn_fwd_t)
nullptr,
}},
{{backward}, REG_BWD_PK({
-GPU_INSTANCE_INTEL(intel::ocl::ref_rnn_bwd_t)
+GPU_INSTANCE_INTEL(intel::ocl::simple_rnn_bwd_t)
nullptr,
})},
});
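The hunk above is the only dispatch-facing edit in this commit: the entry keeps its position in the implementation list, only the class name changes. For readers unfamiliar with the pattern, here is a hedged, much-simplified sketch of how such an impl-list map dispatches; the types and names below are hypothetical, and the real code wraps entries in the GPU_INSTANCE_INTEL and REG_RNN_P macros with impl_list_item_t factories.

#include <functional>
#include <map>
#include <vector>

enum class prop_kind { forward, backward };
enum class status { success, unimplemented };

struct op_desc_t {}; // stand-in for the RNN operation descriptor

// Each entry tries to create one implementation for the given descriptor.
using impl_factory_t = std::function<status(const op_desc_t &)>;

status dispatch(
        const std::map<prop_kind, std::vector<impl_factory_t>> &impl_list,
        prop_kind kind, const op_desc_t &desc) {
    // Candidates are ordered by priority; the first factory that accepts
    // the descriptor wins, so renaming an entry in place preserves both
    // priority and fallback behavior.
    for (const auto &try_create : impl_list.at(kind))
        if (try_create(desc) == status::success) return status::success;
    return status::unimplemented;
}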
6 changes: 3 additions & 3 deletions src/gpu/intel/ocl/rnn/cell_common.cpp
@@ -120,7 +120,7 @@ status_t compute_cell_fwd(const exec_ctx_t &ctx,
}

template <prop_kind_t aprop>
-cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution)) {
+cell_execution_sig((_simple_rnn_common_t<aprop>::cell_execution)) {
const conf_t &rnn = this->pd()->rnn_conf;
const ocl_conf_t &ocl_conf = this->pd()->ocl_conf;
const rnn_offsets_t &offsets = this->pd()->off;
@@ -226,8 +226,8 @@ cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution)) {
}
return status::success;
}
-template cell_execution_sig(ref_rnn_fwd_t::cell_execution);
-template cell_execution_sig(ref_rnn_bwd_t::cell_execution);
+template cell_execution_sig(simple_rnn_fwd_t::cell_execution);
+template cell_execution_sig(simple_rnn_bwd_t::cell_execution);
} // namespace ocl
} // namespace intel
} // namespace gpu
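The `template cell_execution_sig(...)` lines renamed above are explicit instantiations written through a signature macro, an idiom that recurs in every file of this diff. A hedged toy version of the pattern follows; the macro body and all names are simplified stand-ins, since the real cell_execution_sig carries the full oneDNN parameter list.

enum class prop_kind { forward, backward };

// One macro spells out the complete member-function signature, so the same
// spelling serves the in-class declaration, the out-of-class definition,
// and the explicit instantiation.
#define cell_execution_sig(f) int f(int iter)

template <prop_kind aprop>
struct rnn_common_t {
    cell_execution_sig(cell_execution); // expands to: int cell_execution(int iter);
};

// The extra parentheses around the qualified name are a valid (redundant)
// declarator grouping, matching the style in cell_common.cpp above.
template <prop_kind aprop>
cell_execution_sig((rnn_common_t<aprop>::cell_execution)) {
    return aprop == prop_kind::forward ? iter + 1 : iter - 1;
}

using rnn_fwd_t = rnn_common_t<prop_kind::forward>;
using rnn_bwd_t = rnn_common_t<prop_kind::backward>;

// Explicit instantiations mirror the renamed lines in the diff: they force
// the compiler to emit both specializations from this translation unit.
template cell_execution_sig(rnn_fwd_t::cell_execution);
template cell_execution_sig(rnn_bwd_t::cell_execution);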
6 changes: 3 additions & 3 deletions src/gpu/intel/ocl/rnn/cell_gru.cpp
@@ -28,7 +28,7 @@ using namespace dnnl::impl::utils;
using namespace rnn_utils;

template <prop_kind_t aprop>
-cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution_gru)) {
+cell_execution_sig((_simple_rnn_common_t<aprop>::cell_execution_gru)) {
const conf_t &rnn = this->pd()->rnn_conf;
const ocl_conf_t &ocl_conf = this->pd()->ocl_conf;
const rnn_offsets_t &offsets = this->pd()->off;
@@ -152,8 +152,8 @@ cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution_gru)) {
}
return status::success;
}
-template cell_execution_sig(ref_rnn_fwd_t::cell_execution_gru);
-template cell_execution_sig(ref_rnn_bwd_t::cell_execution_gru);
+template cell_execution_sig(simple_rnn_fwd_t::cell_execution_gru);
+template cell_execution_sig(simple_rnn_bwd_t::cell_execution_gru);
} // namespace ocl
} // namespace intel
} // namespace gpu
6 changes: 3 additions & 3 deletions src/gpu/intel/ocl/rnn/cell_gru_lbr.cpp
@@ -29,11 +29,11 @@ namespace ocl {
using namespace dnnl::impl::utils;
using namespace rnn_utils;

-template cell_execution_sig(ref_rnn_fwd_t::cell_execution_gru_lbr);
-template cell_execution_sig(ref_rnn_bwd_t::cell_execution_gru_lbr);
+template cell_execution_sig(simple_rnn_fwd_t::cell_execution_gru_lbr);
+template cell_execution_sig(simple_rnn_bwd_t::cell_execution_gru_lbr);

template <prop_kind_t aprop>
-cell_execution_sig((_ref_rnn_common_t<aprop>::cell_execution_gru_lbr)) {
+cell_execution_sig((_simple_rnn_common_t<aprop>::cell_execution_gru_lbr)) {
const conf_t &rnn = this->pd()->rnn_conf;
const ocl_conf_t &ocl_conf = this->pd()->ocl_conf;
const rnn_offsets_t &offsets = this->pd()->off;
20 changes: 10 additions & 10 deletions src/gpu/intel/ocl/rnn/ref_rnn.cl
@@ -122,7 +122,7 @@ float activation_bwd(float s, float alpha, float cliping) {
}

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void
-ref_rnn_copy_init_layer(__global WS_STATE_DATA_T *dst_base,
+simple_rnn_copy_init_layer(__global WS_STATE_DATA_T *dst_base,
__global char *src_base, __global AUX_DATA_T *scratch_diff_states,
int lr, int rl, int batch, int dhc, int slc, int n_iter, int n_layer,
int n_dir, int n_states, int states_ws_ld, int scratch_diff_states_ld,
@@ -199,7 +199,7 @@ ref_rnn_copy_init_layer(__global WS_STATE_DATA_T *dst_base,
#endif
}

-__kernel void ref_rnn_copy_init_iter(__global WS_STATE_DATA_T *dst_base,
+__kernel void simple_rnn_copy_init_iter(__global WS_STATE_DATA_T *dst_base,
__global AUX_DATA_T *dst_c_base, __global char *src_base,
__global char *src_c_base, __global AUX_DATA_T *scratch_diff_states,
int batch, int dhc, int sic, int n_iter, int n_layer, int n_dir,
@@ -273,7 +273,7 @@ __kernel void ref_rnn_copy_init_iter(__global WS_STATE_DATA_T *dst_base,
}

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void
-ref_rnn_copy_res_layer(
+simple_rnn_copy_res_layer(
__global WS_STATE_DATA_T *src_base, __global char *dst_base,
__global AUX_DATA_T *scratch_diff_states, int lr, int rl, int batch,
int dhc, int slc, int n_iter, int n_layer, int n_dir, int n_states,
@@ -363,7 +363,7 @@ ref_rnn_copy_res_layer(
#endif
}

-__kernel void ref_rnn_copy_res_iter(
+__kernel void simple_rnn_copy_res_iter(
__global WS_STATE_DATA_T *src_base, __global AUX_DATA_T *src_c_base,
__global char *dst_base, __global char *dst_c_base,
__global AUX_DATA_T *scratch_diff_states, int batch, int dhc, int sic,
@@ -578,7 +578,7 @@ float deq_w(ACC_DATA_T s, int gate, int j, __global float *scales,

// for int8 LSTM
__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void
-ref_rnn_elemwise_fwd(int dir, int lay, int iter,
+simple_rnn_elemwise_fwd(int dir, int lay, int iter,
__global ACC_DATA_T *scratch_gates_, dim_t scratch_gates_off,
__global float *scales, float alpha, float data_shift, float data_scale,
__global float *tm_scales, __global WS_STATE_DATA_T *h_states_t_l_,
@@ -631,7 +631,7 @@ ref_rnn_elemwise_fwd(int dir, int lay, int iter,
#else

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void
-ref_rnn_elemwise_fwd(__global ACC_DATA_T *scratch_gates_,
+simple_rnn_elemwise_fwd(__global ACC_DATA_T *scratch_gates_,
dim_t scratch_gates_off, __global BIAS_DATA_T *bias_, dim_t bias_off,
float alpha, __global float *tm_scales,
__global WS_STATE_DATA_T *h_states_t_l_, dim_t h_states_t_l_off,
@@ -768,7 +768,7 @@ ref_rnn_elemwise_fwd(__global ACC_DATA_T *scratch_gates_,
// same memory when sizeof(SRC_DATA_T) == sizeof(AUX_DATA_T) or when
// scratch_gates is unused in order to reduce memory usage
__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void
-ref_rnn_elemwise_bwd(int dir, int lay, int iter,
+simple_rnn_elemwise_bwd(int dir, int lay, int iter,
__global SRC_DATA_T *scratch_diff_gates_, dim_t scratch_diff_gates_off,
__global AUX_DATA_T *scratch_gates_, dim_t scratch_gates_off,
__global BIAS_DATA_T *bias_, dim_t bias_off, float alpha,
@@ -1046,7 +1046,7 @@ ref_rnn_elemwise_bwd(int dir, int lay, int iter,
}
}
#else
-__kernel void ref_rnn_elemwise_bwd() {}
+__kernel void simple_rnn_elemwise_bwd() {}
#endif // !IS_FWD

#if CELL_COMP_ENABLED
@@ -1384,7 +1384,7 @@ void cell_common(const_wei_layer_cell_t wei_layer,
}

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void
-ref_rnn_cell_fwd(__global const WEI_LAYER_DATA_T *wei_layer_,
+simple_rnn_cell_fwd(__global const WEI_LAYER_DATA_T *wei_layer_,
dim_t wei_layer_off, int64x5_t wei_layer_strides_,
__global const WEI_ITER_DATA_T *wei_iter_, dim_t wei_iter_off,
int64x5_t wei_iter_strides_, __global const AUX_DATA_T *cell_layer_,
@@ -1493,5 +1493,5 @@ ref_rnn_cell_fwd(__global const WEI_LAYER_DATA_T *wei_layer_,
}

#else
-__kernel void ref_rnn_cell_fwd() {}
+__kernel void simple_rnn_cell_fwd() {}
#endif
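Note the pattern at the end of this file: when a build flag compiles a kernel body out, an empty stub with the same name is still defined. This is presumably so that every compiled variant of the program exports the full set of kernel names the host expects (see get_kernel_names() in rnn_utils.hpp below). A hedged OpenCL sketch with a hypothetical flag name:

#if MY_FEATURE_ENABLED
__kernel void my_feature_kernel(__global float *dst) {
    // Real body, compiled only when the feature is enabled.
    dst[get_global_id(0)] = 0.0f;
}
#else
// Empty stub: the binary still contains a kernel with this name, so a
// host-side clCreateKernel(prog, "my_feature_kernel", ...) succeeds in
// every configuration.
__kernel void my_feature_kernel() {}
#endif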
62 changes: 31 additions & 31 deletions src/gpu/intel/ocl/rnn/ref_rnn.cpp
@@ -497,7 +497,7 @@ inline status_t init_ocl_conf<prop_kind::backward>(
}

template <>
-status_t _ref_rnn_common_t<prop_kind::forward>::pd_t::set_default_params() {
+status_t _simple_rnn_common_t<prop_kind::forward>::pd_t::set_default_params() {
using namespace format_tag;
if (src_layer_md_.format_kind == format_kind::any)
CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
@@ -525,7 +525,7 @@ status_t _ref_rnn_common_t<prop_kind::forward>::pd_t::set_default_params() {
}

template <>
-status_t _ref_rnn_common_t<prop_kind::backward>::pd_t::set_default_params() {
+status_t _simple_rnn_common_t<prop_kind::backward>::pd_t::set_default_params() {
using namespace format_tag;
int arch_ld = is_xe_hpc ? 128 : 64;
if (src_layer_md_.format_kind == format_kind::any)
@@ -600,7 +600,7 @@ status_t _ref_rnn_common_t<prop_kind::backward>::pd_t::set_default_params() {
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::pd_t::init(impl::engine_t *engine) {
+status_t _simple_rnn_common_t<aprop>::pd_t::init(impl::engine_t *engine) {
using namespace prop_kind;
using namespace utils;
using namespace rnn_utils;
@@ -966,7 +966,7 @@ status_t _ref_rnn_common_t<aprop>::pd_t::init(impl::engine_t *engine) {
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::init(impl::engine_t *engine) {
+status_t _simple_rnn_common_t<aprop>::init(impl::engine_t *engine) {
using namespace rnn_utils;

switch (pd()->cell_kind()) {
@@ -1055,7 +1055,7 @@ status_t _ref_rnn_common_t<aprop>::init(impl::engine_t *engine) {
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::init_res_storage(
+status_t _simple_rnn_common_t<aprop>::init_res_storage(
impl::engine_t *engine, gpu_resource_t *r) const {
if (pd()->rnn_conf.is_int8 && pd()->rnn_conf.copy_bias) {
dim_t size = pd()->rnn_conf.n_gates * pd()->rnn_conf.dhc
@@ -1098,7 +1098,7 @@ status_t _ref_rnn_common_t<aprop>::init_res_storage(
}

template <prop_kind_t aprop>
-gemm_sig((_ref_rnn_common_t<aprop>::gemm_primitive)) {
+gemm_sig((_simple_rnn_common_t<aprop>::gemm_primitive)) {
// We flip A and B here since the GEMM API is row major but the
// RNN code describes GEMM in column major fashion
gemm_exec_args_t gemm_args;
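The comment above is worth unpacking: a column-major GEMM C = A·B can be served by a row-major GEMM because a column-major buffer reinterpreted as row-major is the transpose, and Cᵀ = Bᵀ·Aᵀ. Swapping the A and B arguments (and M with N) therefore computes the right result with no data movement. A hedged plain-loop illustration follows, assuming dense buffers whose leading dimension equals the row length; this is toy code, not the oneDNN gemm API.

// Row-major GEMM: C[m][n] = sum_k A[m][k] * B[k][n].
void rowmajor_gemm(const float *A, const float *B, float *C,
        int M, int N, int K) {
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            C[m * N + n] = acc;
        }
}

// To get column-major C(MxN) = A(MxK) * B(KxN) from the routine above,
// flip the operands and swap M with N -- no copies or transposes needed:
//   rowmajor_gemm(B_colmajor, A_colmajor, C_colmajor, N, M, K);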
@@ -1178,7 +1178,7 @@

//*************** Grid computations strategy: linear ***************//
template <prop_kind_t aprop>
-grid_execution_sig((_ref_rnn_common_t<aprop>::linear_execution)) {
+grid_execution_sig((_simple_rnn_common_t<aprop>::linear_execution)) {
const conf_t &rnn = pd()->rnn_conf;
dim_t n_layer = rnn.n_layer;
dim_t n_dir = rnn.n_dir;
@@ -1264,7 +1264,7 @@ grid_execution_sig((_ref_rnn_common_t<aprop>::linear_execution)) {
//********* GRID computations strategy: utility functions **********//

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::bias_prepare(const exec_ctx_t &ctx,
+status_t _simple_rnn_common_t<aprop>::bias_prepare(const exec_ctx_t &ctx,
compute::compute_stream_t *compute_stream, dim_t n_layer, dim_t n_dir,
dim_t n_bias, dim_t n_gates, dim_t dhc, const memory_storage_t &ws_bias,
const memory_storage_t &scales, const memory_storage_t &wei_layer,
@@ -1297,7 +1297,7 @@ status_t _ref_rnn_common_t<aprop>::bias_prepare(const exec_ctx_t &ctx,
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::copy_init_layer(const exec_ctx_t &ctx,
+status_t _simple_rnn_common_t<aprop>::copy_init_layer(const exec_ctx_t &ctx,
compute::compute_stream_t *compute_stream, bool lr, bool rl,
dim_t batch, dim_t dhc, dim_t slc, dim_t n_iter, dim_t n_layer,
dim_t n_dir, dim_t n_states, dim_t states_ws_ld,
@@ -1356,7 +1356,7 @@ status_t _ref_rnn_common_t<aprop>::copy_init_layer(const exec_ctx_t &ctx,
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::copy_init_iter(const exec_ctx_t &ctx,
+status_t _simple_rnn_common_t<aprop>::copy_init_iter(const exec_ctx_t &ctx,
compute::compute_stream_t *compute_stream, dim_t batch, dim_t dhc,
dim_t sic, dim_t n_iter, dim_t n_layer, dim_t n_dir, dim_t n_states,
dim_t states_ws_ld, dim_t scratch_diff_states_ld,
@@ -1430,7 +1430,7 @@ status_t _ref_rnn_common_t<aprop>::copy_init_iter(const exec_ctx_t &ctx,
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::copy_res_layer(const exec_ctx_t &ctx,
+status_t _simple_rnn_common_t<aprop>::copy_res_layer(const exec_ctx_t &ctx,
compute::compute_stream_t *compute_stream, bool lr, bool rl,
dim_t batch, dim_t dhc, dim_t slc, dim_t n_iter, dim_t n_layer,
dim_t n_dir, dim_t n_states, dim_t states_ws_ld,
@@ -1492,7 +1492,7 @@ status_t _ref_rnn_common_t<aprop>::copy_res_layer(const exec_ctx_t &ctx,
}

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::copy_res_iter(const exec_ctx_t &ctx,
+status_t _simple_rnn_common_t<aprop>::copy_res_iter(const exec_ctx_t &ctx,
compute::compute_stream_t *compute_stream, dim_t batch, dim_t dhc,
dim_t sic, dim_t n_iter, dim_t n_layer, dim_t n_dir, dim_t n_states,
dim_t states_ws_ld, dim_t scratch_diff_states_ld,
@@ -1569,7 +1569,7 @@ status_t _ref_rnn_common_t<aprop>::copy_res_iter(const exec_ctx_t &ctx,
//********************* Execution function *********************//

template <prop_kind_t aprop>
-status_t _ref_rnn_common_t<aprop>::execute_(const exec_ctx_t &ctx) const {
+status_t _simple_rnn_common_t<aprop>::execute_(const exec_ctx_t &ctx) const {

impl::engine_t *engine = ctx.stream()->engine();
auto *compute_stream
@@ -1736,40 +1736,40 @@ status_t _ref_rnn_common_t<aprop>::execute_(const exec_ctx_t &ctx) const {
};
// Fix for MSVS warning C4661.
template <>
-cell_execution_sig(ref_rnn_fwd_t::cell_execution);
+cell_execution_sig(simple_rnn_fwd_t::cell_execution);
template <>
-cell_execution_sig(ref_rnn_bwd_t::cell_execution);
+cell_execution_sig(simple_rnn_bwd_t::cell_execution);
template <>
-cell_execution_sig(ref_rnn_fwd_t::cell_execution_gru);
+cell_execution_sig(simple_rnn_fwd_t::cell_execution_gru);
template <>
-cell_execution_sig(ref_rnn_bwd_t::cell_execution_gru);
+cell_execution_sig(simple_rnn_bwd_t::cell_execution_gru);
template <>
-cell_execution_sig(ref_rnn_fwd_t::cell_execution_gru_lbr);
+cell_execution_sig(simple_rnn_fwd_t::cell_execution_gru_lbr);
template <>
-cell_execution_sig(ref_rnn_bwd_t::cell_execution_gru_lbr);
+cell_execution_sig(simple_rnn_bwd_t::cell_execution_gru_lbr);
template <>
-elemwise_sig(ref_rnn_fwd_t::rnn_elemwise);
+elemwise_sig(simple_rnn_fwd_t::rnn_elemwise);
template <>
-elemwise_sig(ref_rnn_bwd_t::rnn_elemwise);
+elemwise_sig(simple_rnn_bwd_t::rnn_elemwise);
template <>
-elemwise_sig(ref_rnn_fwd_t::lstm_elemwise);
+elemwise_sig(simple_rnn_fwd_t::lstm_elemwise);
template <>
-elemwise_sig(ref_rnn_bwd_t::lstm_elemwise);
+elemwise_sig(simple_rnn_bwd_t::lstm_elemwise);
template <>
-elemwise_sig(ref_rnn_fwd_t::lstm_elemwise_u8s8);
+elemwise_sig(simple_rnn_fwd_t::lstm_elemwise_u8s8);
template <>
-elemwise_sig(ref_rnn_bwd_t::lstm_elemwise_u8s8);
+elemwise_sig(simple_rnn_bwd_t::lstm_elemwise_u8s8);
template <>
-elemwise_sig_gru_lbr(ref_rnn_fwd_t::gru_lbr_elemwise);
+elemwise_sig_gru_lbr(simple_rnn_fwd_t::gru_lbr_elemwise);
template <>
-elemwise_sig_gru_lbr(ref_rnn_bwd_t::gru_lbr_elemwise);
+elemwise_sig_gru_lbr(simple_rnn_bwd_t::gru_lbr_elemwise);
template <>
-elemwise_sig_gru(ref_rnn_fwd_t::gru_elemwise);
+elemwise_sig_gru(simple_rnn_fwd_t::gru_elemwise);
template <>
-elemwise_sig_gru(ref_rnn_bwd_t::gru_elemwise);
+elemwise_sig_gru(simple_rnn_bwd_t::gru_elemwise);

-template struct _ref_rnn_common_t<prop_kind::forward>;
-template struct _ref_rnn_common_t<prop_kind::backward>;
+template struct _simple_rnn_common_t<prop_kind::forward>;
+template struct _simple_rnn_common_t<prop_kind::backward>;

} // namespace ocl
} // namespace intel
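The block of `template <> ... ;` declarations above exists, per the in-diff comment, to silence MSVC warning C4661 ("no suitable definition provided for explicit template instantiation request"): the member definitions live in the cell_*.cpp files, so this translation unit declares the specializations before the explicit class instantiations. A hedged minimal reproduction of the shape, with toy names:

template <int N>
struct common_t {
    int f(); // defined in another .cpp for each N
};

// Declaring the member specialization tells MSVC that no definition is
// expected in this translation unit...
template <>
int common_t<0>::f();

// ...so the explicit class instantiation below no longer raises C4661.
template struct common_t<0>;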
10 changes: 5 additions & 5 deletions src/gpu/intel/ocl/rnn/ref_rnn.hpp
@@ -54,10 +54,10 @@ enum gemm_kind_t {
};

template <prop_kind_t aprop>
-struct _ref_rnn_common_t : public gpu_primitive_t {
+struct _simple_rnn_common_t : public gpu_primitive_t {
using gpu_primitive_t::gpu_primitive_t;

-using class_name = _ref_rnn_common_t<aprop>;
+using class_name = _simple_rnn_common_t<aprop>;

typedef elemwise_sig((class_name::*elemwise_f));
typedef elemwise_sig_gru((class_name::*elemwise_gru_f));
@@ -76,7 +76,7 @@ struct _ref_rnn_common_t : public gpu_primitive_t {

pd_t(const pd_t &other) = default;

DECLARE_COMMON_PD_T("ref:any", class_name);
DECLARE_COMMON_PD_T("ocl:simple:any", class_name);

status_t init(impl::engine_t *engine);

@@ -259,8 +259,8 @@ struct _ref_rnn_common_t : public gpu_primitive_t {

enum { SCALES_ = 0, TM_SCALES_ = 1 };
};
-using ref_rnn_fwd_t = _ref_rnn_common_t<prop_kind::forward>;
-using ref_rnn_bwd_t = _ref_rnn_common_t<prop_kind::backward>;
+using simple_rnn_fwd_t = _simple_rnn_common_t<prop_kind::forward>;
+using simple_rnn_bwd_t = _simple_rnn_common_t<prop_kind::backward>;
} // namespace ocl
} // namespace intel
} // namespace gpu
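The DECLARE_COMMON_PD_T change is the user-visible half of this rename: the implementation now reports itself as "ocl:simple:any" instead of "ref:any". A hedged sketch of how an application can observe this through the public C++ API follows; the constructor shape assumes the oneDNN 3.x vanilla_rnn_forward overload without the alpha parameter, the dimensions are toy values, error handling is omitted, and whether this particular implementation is picked depends on the device, shapes, and build.

#include <iostream>
#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::gpu, 0); // assumes a GPU device is present
    const memory::dim T = 2, N = 1, C = 16, G = 1, L = 1, D = 1;
    auto md = [](memory::dims dims, memory::format_tag tag) {
        return memory::desc(dims, memory::data_type::f32, tag);
    };
    vanilla_rnn_forward::primitive_desc pd(eng, prop_kind::forward_inference,
            algorithm::eltwise_tanh, rnn_direction::unidirectional_left2right,
            md({T, N, C}, memory::format_tag::tnc), // src_layer
            md({L, D, N, C}, memory::format_tag::ldnc), // src_iter
            md({L, D, C, G, C}, memory::format_tag::any), // weights_layer
            md({L, D, C, G, C}, memory::format_tag::any), // weights_iter
            md({L, D, G, C}, memory::format_tag::ldgo), // bias
            md({T, N, C}, memory::format_tag::tnc), // dst_layer
            md({L, D, N, C}, memory::format_tag::ldnc)); // dst_iter
    std::cout << pd.impl_info_str() << "\n"; // e.g. "ocl:simple:any"
    return 0;
}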
10 changes: 5 additions & 5 deletions src/gpu/intel/ocl/rnn/rnn_utils.hpp
@@ -161,11 +161,11 @@ struct ocl_conf_t {
bundle, get_kernel_names(), kernel_ctx);
}
const std::vector<const char *> &get_kernel_names() const {
-static const std::vector<const char *> names
-        = {"ref_rnn_bias_prepare", "ref_rnn_copy_init_layer",
-                "ref_rnn_copy_init_iter", "ref_rnn_copy_res_layer",
-                "ref_rnn_copy_res_iter", "ref_rnn_elemwise_fwd",
-                "ref_rnn_elemwise_bwd", "ref_rnn_cell_fwd"};
+static const std::vector<const char *> names = {"ref_rnn_bias_prepare",
+        "simple_rnn_copy_init_layer", "simple_rnn_copy_init_iter",
+        "simple_rnn_copy_res_layer", "simple_rnn_copy_res_iter",
+        "simple_rnn_elemwise_fwd", "simple_rnn_elemwise_bwd",
+        "simple_rnn_cell_fwd"};
return names;
}
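This list is the host-side contract with ref_rnn.cl: kernels are created from the compiled binary by these exact strings, which is why the rename must touch the .cl file and this header together (note that ref_rnn_bias_prepare keeps its old name here). A hedged sketch of name-based kernel creation in raw OpenCL, with hypothetical handles:

#include <CL/cl.h>

cl_int create_all(cl_program prog, const char **names, int n,
        cl_kernel *out) {
    for (int i = 0; i < n; ++i) {
        cl_int err;
        // Lookup is by string at runtime; a stale entry after a rename
        // fails here with CL_INVALID_KERNEL_NAME.
        out[i] = clCreateKernel(prog, names[i], &err);
        if (err != CL_SUCCESS) return err;
    }
    return CL_SUCCESS;
}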

(Diff for the ninth changed file was not loaded in this view.)