Skip to content

Commit

Permalink
src: ocl: rnn: to divide states in workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
skazakov1 committed Jul 20, 2019
1 parent be69a71 commit af38389
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 59 deletions.
12 changes: 7 additions & 5 deletions src/ocl/rnn/cell_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
/*
* Common for RNN and LSTM cell execution
*/
#include "ref_rnn.hpp"
#include "ocl/rnn/ref_rnn.hpp"

namespace mkldnn {
namespace impl {
Expand All @@ -33,6 +33,7 @@ cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execu
const rnn_conf_t &rnn_conf = this->pd()->rnn_conf_;

if (aprop == prop_kind::forward) {

AOC<size_t, 3> off_weights_i(weights_input, n_layer, n_dir,
n_parts_weights_layer);
AOC<size_t, 3> off_weights_st(weights_states, n_layer, n_dir,
Expand All @@ -43,10 +44,10 @@ cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execu

cl_ulong offset_states = (cl_ulong)(ws_states_offset_
+ OFF4(lay + 1, n_layer + 1, dir, n_dir, iter, n_iter + 1,
0, batch * n_states * rnn_conf.states_ws_ld));
0, batch * rnn_conf.states_ws_ld));
cl_ulong offset_input = (cl_ulong)(ws_states_offset_
+ OFF4(lay, n_layer + 1, dir, n_dir, iter + 1, n_iter + 1,
0, batch * n_states * rnn_conf.states_ws_ld));
0, batch * rnn_conf.states_ws_ld));
cl_ulong offset_gates = (cl_ulong)(ws_gates_offset_
+ OFF4(lay, n_layer, dir, n_dir, iter, n_iter,
0, batch * rnn_conf.gates_ws_ld));
Expand All @@ -61,6 +62,7 @@ cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execu
gemm_iter);
(this->*elemwise_func)(ctx, dir, lay, iter, dic, wic, batch, workspace,
bias);

} else { // backward

AOC<size_t, 3> off_weights_i(weights_input, n_layer, n_dir,
Expand Down Expand Up @@ -98,7 +100,7 @@ cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execu
n_gates * dic, slc, workspace, ws_gates_offset_ + OFF4(lay, n_layer,
dir, n_dir, iter, n_iter, 0, batch * rnn_conf.gates_ws_ld),
workspace, ws_states_offset_ + OFF4(lay, n_layer + 1, dir,
n_dir, iter + 1, n_iter + 1, 0, n_states * batch
n_dir, iter + 1, n_iter + 1, 0, batch
* rnn_conf.states_ws_ld),
diff_weights_layer, OFF3(lay, n_layer, dir, n_dir, 0,
rnn_conf.diff_weights_layer_nld * rnn_conf.diff_weights_layer_ld),
Expand All @@ -108,7 +110,7 @@ cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execu
n_gates * dic, sic, workspace, ws_gates_offset_ + OFF4(lay, n_layer,
dir, n_dir, iter, n_iter, 0, batch * rnn_conf.gates_ws_ld),
workspace, ws_states_offset_ + OFF4(lay + 1, n_layer + 1, dir,
n_dir, iter, n_iter + 1, 0, n_states * batch
n_dir, iter, n_iter + 1, 0, batch
* rnn_conf.states_ws_ld),
diff_weights_iter, OFF3(lay, n_layer, dir, n_dir, 0,
rnn_conf.diff_weights_iter_nld * rnn_conf.diff_weights_iter_ld),
Expand Down
4 changes: 2 additions & 2 deletions src/ocl/rnn/jit_ref_rnn_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ struct jit_ref_rnn_kernel {

kernel_ctx.define_int("WS_GATES_OFFSET", jrnn.ws_gates_offset);
kernel_ctx.define_int("WS_STATES_OFFSET", jrnn.ws_states_offset);
kernel_ctx.define_int("WS_C_STATE_OFFSET", jrnn.ws_c_state_offset);
kernel_ctx.define_int(
"WS_DIFF_STATES_OFFSET", jrnn.ws_diff_states_offset);
"WS_DIFF_STATES_OFFSET", jrnn.ws_diff_states_offset);
kernel_ctx.define_int("WS_GRID_COMP_OFFSET", jrnn.ws_grid_comp_offset);
kernel_ctx.define_int("WS_CELL_COMP_OFFSET", jrnn.ws_cell_comp_offset);

kernel_ctx.define_int("STATES_WS_LD", jrnn.states_ws_ld);
kernel_ctx.define_int("GATES_WS_LD", jrnn.gates_ws_ld);
kernel_ctx.define_int("DEBUGPRINT", DEBUGPRINT);
Expand Down
2 changes: 1 addition & 1 deletion src/ocl/rnn/ocl_rnn_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "ocl/ocl_engine.hpp"
#include "rnn_utils.hpp"
#include "ocl/rnn/rnn_utils.hpp"

namespace mkldnn {
namespace impl {
Expand Down
11 changes: 2 additions & 9 deletions src/ocl/rnn/ref_postgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
* limitations under the License.
*******************************************************************************/

#include "ref_rnn.hpp"
#include "../cl_executor.hpp"
#include "ocl/rnn/ref_rnn.hpp"

namespace mkldnn {
namespace impl {
Expand All @@ -25,13 +24,10 @@ template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
elemwise_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::rnn_elemwise)) {
auto *compute_stream
= utils::downcast<compute::compute_stream_t *>(ctx.stream());

auto nd_range = compute::nd_range_t({ batch, dic });

const compute::kernel_t &kernel = (aprop == prop_kind::forward)
? elemwise_fwd_kernel_
: elemwise_bwd_kernel_;

compute::kernel_arg_list_t arg_list;
arg_list.set(0, dir);
arg_list.set(1, lay);
Expand All @@ -46,16 +42,13 @@ template elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);


template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
elemwise_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::lstm_elemwise))
{
elemwise_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::lstm_elemwise)) {
auto *compute_stream
= utils::downcast<compute::compute_stream_t *>(ctx.stream());

auto nd_range = compute::nd_range_t({ batch, dic });
const compute::kernel_t &kernel = (aprop == prop_kind::forward)
? elemwise_fwd_kernel_
: elemwise_bwd_kernel_;

compute::kernel_arg_list_t arg_list;
arg_list.set(0, dir);
arg_list.set(1, lay);
Expand Down
94 changes: 59 additions & 35 deletions src/ocl/rnn/ref_rnn.cl
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@
#define OFF2(i0,D0,i1,D1) \
((i0)*(D1)+(i1))

#define OFF_WS_STATES_LAYER(i0,i1,i2,i3,i4) \
OFF5((i0), N_DIR, (i1), N_ITER + 1, (i2), N_STATES, (i3), BATCH, \
(i4), STATES_WS_LD)
#define OFF_WS_STATES(i0,i1,i2,i3,i4,i5) \
OFF6((i0), N_LAYER + 1, (i1), N_DIR, (i2), N_ITER + 1, (i3), N_STATES, \
(i4), BATCH, (i5), STATES_WS_LD)
// used for the both H- and C-states
#define OFF_WS_STATE(i0,i1,i2,i3,i4) \
OFF5((i0), N_LAYER + 1, (i1), N_DIR, (i2), N_ITER + 1, \
(i3), BATCH, (i4), STATES_WS_LD)

#define OFF_WS_DIFF_STATES(i0,i1,i2,i3,i4,i5) \
OFF6((i0), N_LAYER + 1,(i1), N_DIR, (i2), N_ITER + 1, (i3), N_STATES + 1, \
(i4), BATCH, (i5), STATES_WS_LD)
Expand All @@ -55,7 +54,7 @@
// for cell - shorter forms

#define CELL_WS_GATES(i3,i4,i5) OFF_WS_GATES(0,0,0,i3,i4,i5)
#define CELL_WS_STATES(i3,i4,i5) OFF_WS_STATES(0,0,0,i3,i4,i5)
#define CELL_WS_STATE(i4,i5) OFF_WS_STATE(0,0,0,i4,i5)
#define CELL_WS_DIFF_STATES(i3,i4,i5) OFF_WS_DIFF_STATES(0,0,0,i3,i4,i5)

#define OFF_KER_BIAS(i0,i1) \
Expand Down Expand Up @@ -218,6 +217,7 @@ __kernel void ref_rnn_copy_init_layer_kernel(__global DATA_T *ws,
__global DATA_T *src_base, int lr, int rl) {

#if IS_FWD

const int it = get_global_id(2);
const int b = get_global_id(1);
const int c = get_global_id(0);
Expand All @@ -226,14 +226,16 @@ __kernel void ref_rnn_copy_init_layer_kernel(__global DATA_T *ws,
__global DATA_T *src = src_base + SRC_L_OFF(it, 0, 0 ) + b * SLC + c;

if (lr) {
dst = dst_base + OFF_WS_STATES_LAYER(0, it+1, 0, b, c);
dst = dst_base + OFF_WS_STATE(0, 0, it+1, b, c);
dst[0] = src[0];
}
if (rl) {
dst = dst_base + OFF_WS_STATES_LAYER(N_DIR-1, N_ITER-it, 0, b, c);
dst = dst_base + OFF_WS_STATE(0, N_DIR-1, N_ITER-it, b, c);
dst[0] = src[0];
}

#else

const int it = get_global_id(1);
const int b = get_global_id(0);

Expand Down Expand Up @@ -279,12 +281,13 @@ __kernel void ref_rnn_copy_init_iter_kernel(__global DATA_T *ws,
#if IS_FWD
__global DATA_T *dst = ws + WS_STATES_OFFSET;
if (s < SIC)
dst[OFF_WS_STATES(lay + 1, dir, 0, 0, b, s)] = src_base
dst[OFF_WS_STATE(lay + 1, dir, 0, b, s)] = src_base
? src_base[SRC_I_OFF(lay, dir, b, s)]
: 0.0f;
#if WITH_SRC_ITER_C
__global DATA_T *dst_c = ws + WS_C_STATE_OFFSET;
if (s < DIC)
dst[OFF_WS_STATES(lay + 1, dir, 0, 1, b, s)] = src_c_base
dst_c[OFF_WS_STATE(lay + 1, dir, 0, b, s)] = src_c_base
? src_c_base[SRC_I_C_OFF(lay, dir, b, s)]
: 0.0f;
#endif
Expand Down Expand Up @@ -315,16 +318,16 @@ __kernel void ref_rnn_copy_res_layer_kernel(__global DATA_T *ws,
int dir = 0;
if (lr) {
dst_base[DST_L_OFF(it, b, dir * DIC + s)] =
src_base[OFF_WS_STATES(N_LAYER, dir, it+1, 0, b, s)];
src_base[OFF_WS_STATE(N_LAYER, dir, it+1, b, s)];
dir = 1;
}
if (rl) {
#if DIRECTION_KIND == SUM
dst_base[DST_L_OFF(it, b, s)] +=
src_base[OFF_WS_STATES(N_LAYER, dir, N_ITER - it, 0, b, s)];
src_base[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)];
#else
dst_base[DST_L_OFF(it, b, dir * DIC + s)] =
src_base[OFF_WS_STATES(N_LAYER, dir, N_ITER - it, 0, b, s)];
src_base[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)];
#endif
}
#else // BWD
Expand Down Expand Up @@ -356,12 +359,13 @@ __kernel void ref_rnn_copy_res_iter_kernel(__global DATA_T *ws,
__global DATA_T *src_base = ws + WS_STATES_OFFSET;
if (dst_base && s < DIC) {
dst_base[DST_I_OFF(lay, dir, b, s)] =
src_base[OFF_WS_STATES(lay + 1, dir, N_ITER, 0, b, s)];
src_base[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)];
}
#if WITH_DST_ITER_C
__global DATA_T *src_c_base = ws + WS_C_STATE_OFFSET;
if (dst_c_base && s < DIC) {
dst_c_base[DST_I_C_OFF(lay, dir, b, s)] =
src_base[OFF_WS_STATES(lay + 1, dir, N_ITER, 1, b, s)];
src_c_base[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)];
}
#endif
#else
Expand Down Expand Up @@ -408,24 +412,41 @@ __kernel void ref_rnn_ws_print_kernel(
}
}
}
printf("ws_states: off %d\n", WS_STATES_OFFSET);
printf("[lay,dir,iter,state]\n");

printf("ws_states (H): off %d\n", WS_STATES_OFFSET);
printf("[lay,dir,iter]\n");
wt = ws + WS_STATES_OFFSET;
for (int j = 0; j < N_LAYER+1; j++) {
for (int dir = 0; dir < N_DIR; dir++) {
for (int i = 0; i < N_ITER+1; i++) {
for (int st = 0; st < N_STATES; st++) {
printf("[%d,%d,%d,%d] : ", j, dir, i, st);
printf("[%d,%d,%d] : ", j, dir, i);
for (int b = 0; b < BATCH; b++) {
for (int s = 0; s < WIC; s++) {
printf(" %f", *(wt + OFF_WS_STATES(j,dir,i,st,b,s)));
printf(" %f", *(wt + OFF_WS_STATE(j,dir,i,b,s)));
}
}
printf("\n");
}
}
}

printf("ws_states (C): off %d\n", WS_C_STATE_OFFSET);
printf("[lay,dir,iter]\n");
wt = ws + WS_C_STATE_OFFSET;
for (int j = 0; j < N_LAYER+1; j++) {
for (int dir = 0; dir < N_DIR; dir++) {
for (int i = 0; i < N_ITER+1; i++) {
printf("[%d,%d,%d] : ", j, dir, i);
for (int b = 0; b < BATCH; b++) {
for (int s = 0; s < WIC; s++) {
printf(" %f", *(wt + OFF_WS_STATE(j,dir,i,b,s)));
}
}
printf("\n");
}
}
}

printf("ws_diff_states: off %d\n",WS_DIFF_STATES_OFFSET);
printf("[lay,dir,state,iter]\n");
wt = ws + WS_DIFF_STATES_OFFSET;
Expand Down Expand Up @@ -453,13 +474,16 @@ __kernel void ref_rnn_elemwise_fwd_kernel(int dir, int lay, int iter,
const int i = get_global_id(0); // batch
const int j = get_global_id(1); // dic

const __global DATA_T *states_tm1_l = ws + WS_STATES_OFFSET
+ OFF_WS_STATES(lay + 1, dir, iter, 0, 0, 0);
const __global DATA_T *c_states_tm1_l = ws + WS_C_STATE_OFFSET
+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0);
const __global DATA_T *bias = bias_base + BIAS_OFF(lay, dir, 0, 0);
__global DATA_T *ws_gates = ws + WS_GATES_OFFSET
+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0);
__global DATA_T *states_t_l = ws + WS_STATES_OFFSET
+ OFF_WS_STATES(lay + 1, dir, iter + 1, 0, 0, 0);

__global DATA_T *h_states_t_l = ws + WS_STATES_OFFSET
+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0);
__global DATA_T *c_states_t_l = ws + WS_C_STATE_OFFSET
+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0);

#if CELL_KIND == VANILLA_LSTM

Expand All @@ -477,18 +501,18 @@ __kernel void ref_rnn_elemwise_fwd_kernel(int dir, int lay, int iter,
ws_gates[CELL_WS_GATES(i, 2, j)] = g_z;
ws_gates[CELL_WS_GATES(i, 3, j)] = g_o;

float Ct = g_f * states_tm1_l[CELL_WS_STATES(1, i, j)] + g_i * g_z;
float Ct = g_f * c_states_tm1_l[CELL_WS_STATE(i, j)] + g_i * g_z;
float Ht = g_o * tanh_fwd(Ct);

states_t_l[CELL_WS_STATES(0, i, j)] = Ht;
states_t_l[CELL_WS_STATES(1, i, j)] = Ct;
h_states_t_l[CELL_WS_STATE(i, j)] = Ht;
c_states_t_l[CELL_WS_STATE(i, j)] = Ct;

#elif CELL_KIND == VANILLA_RNN
float g = activation_fwd((float)ws_gates[CELL_WS_GATES(i, 0, j)]
+ bias[OFF_KER_BIAS(0, j)], 0, 0);

ws_gates[CELL_WS_GATES(i, 0, j)] = g;
states_t_l[CELL_WS_STATES(0, i, j)] = g;
h_states_t_l[CELL_WS_STATE(i, j)] = g;
#else
#error "Wrong cell kind"
#endif
Expand All @@ -502,18 +526,18 @@ __kernel void ref_rnn_elemwise_bwd_kernel(int dir, int lay, int iter,
#if CELL_KIND == VANILLA_LSTM
__global DATA_T *ws_gates = ws + WS_GATES_OFFSET
+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0);
__global DATA_T *states_t_l = ws + WS_STATES_OFFSET
+ OFF_WS_STATES(lay + 1, dir, iter + 1, 0, 0, 0);
__global DATA_T *states_tm1_l = ws + WS_STATES_OFFSET
+ OFF_WS_STATES(lay + 1, dir, iter, 0, 0, 0);
__global DATA_T *c_states_t_l = ws + WS_C_STATE_OFFSET
+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0);
__global DATA_T *c_states_tm1_l = ws + WS_C_STATE_OFFSET
+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0);
__global DATA_T *diff_states_t_l = ws + WS_DIFF_STATES_OFFSET
+ OFF_WS_DIFF_STATES(lay, dir, iter, 0, 0, 0);
__global DATA_T *diff_states_tp1_l = ws + WS_DIFF_STATES_OFFSET
+ OFF_WS_DIFF_STATES(lay, dir, iter + 1, 0, 0, 0);
__global DATA_T *diff_states_t_lp1 = ws + WS_DIFF_STATES_OFFSET
+ OFF_WS_DIFF_STATES(lay + 1, dir, iter, 0, 0, 0);

float Ct = states_t_l[CELL_WS_STATES(1, i, j)];
float Ct = c_states_t_l[CELL_WS_STATE(i, j)];
/// @todo save it in the workspace in fwd pass or recompute it to
/// save bw
float tanhCt = tanh_fwd(Ct);
Expand All @@ -524,7 +548,7 @@ __kernel void ref_rnn_elemwise_bwd_kernel(int dir, int lay, int iter,
+ one_m_square(tanhCt) * ws_gates[CELL_WS_GATES(i, 3, j)]
* dHt;

float dG1 = (float)states_tm1_l[CELL_WS_STATES(1, i, j)] * dCt
float dG1 = (float)c_states_tm1_l[CELL_WS_STATE(i, j)] * dCt
* x_m_square((float)ws_gates[CELL_WS_GATES(i, 1, j)]);
float dG0 = (float)ws_gates[CELL_WS_GATES(i, 2, j)] * dCt
* x_m_square((float)ws_gates[CELL_WS_GATES(i, 0, j)]);
Expand Down
11 changes: 7 additions & 4 deletions src/ocl/rnn/ref_rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
#include "mkldnn_thread.hpp"
#include "mkldnn_traits.hpp"
#include "type_helpers.hpp"
#include "ref_rnn.hpp"
#include "../cl_executor.hpp"
#include "ocl/rnn/ref_rnn.hpp"

namespace mkldnn {
namespace impl {
Expand Down Expand Up @@ -537,11 +536,15 @@ status_t _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(

#if WS_NAN_FILLING
if(rnn_conf.is_fwd) {
DPRINT("DEBUG ws set: (offset, size) states: %ld %ld gates: %ld %ld\n",
ws_states_offset_, rnn_conf.ws_states_size, ws_gates_offset_,
DPRINT("DEBUG ws NaN filling: (offset, size) states: %ld %ld c_states: %ld %ld gates: %ld %ld\n",
ws_states_offset_, rnn_conf.ws_states_size, ws_c_states_offset_, rnn_conf.ws_c_states_size, ws_gates_offset_,
rnn_conf.ws_gates_size);
ws_set(compute_stream, workspace_, ws_states_offset_, NAN,
rnn_conf.ws_states_size);
if (rnn->with_src_iter_c()) {
ws_set(compute_stream, workspace_, ws_c_states_offset_, NAN,
rnn_conf.ws_c_states_size);
}
ws_set(compute_stream, workspace_, ws_gates_offset_, NAN,
rnn_conf.ws_gates_size);
}
Expand Down
9 changes: 6 additions & 3 deletions src/ocl/rnn/rnn_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,12 @@ void rnn_utils::set_rnn_conf(rnn_conf_t &rnn_conf, const rnn_desc_t &rd,

rnn_conf.use_workspace = rnn_conf.is_training;
rnn_conf.ws_states_size = (size_t)(rnn_conf.n_layer + 1) * rnn_conf.n_dir
* (rnn_conf.n_iter + 1) * rnn_conf.n_states * rnn_conf.mb
* rnn_conf.states_ws_ld;
rnn_conf.ws_c_states_size = 0;
* (rnn_conf.n_iter + 1) * rnn_conf.mb * rnn_conf.states_ws_ld;
bool is_lstm = rd.cell_kind == mkldnn_vanilla_lstm;
rnn_conf.ws_c_states_size = is_lstm
? (size_t)(rnn_conf.n_layer + 1) * rnn_conf.n_dir
* (rnn_conf.n_iter + 1) * rnn_conf.mb * rnn_conf.states_ws_ld
: 0;
rnn_conf.ws_diff_states_size = rnn_conf.is_training
? (size_t)(rnn_conf.n_layer + 1) * rnn_conf.n_dir * (rnn_conf.n_iter + 1)
* (rnn_conf.n_states + 1) * rnn_conf.mb * rnn_conf.states_ws_ld
Expand Down

0 comments on commit af38389

Please sign in to comment.