Skip to content

Commit

Permalink
cpu: conv: int8: depthwise: use bool for true/false flagging
Browse files: browse the repository at this point in the history
  • Loading branch information
kwiersch committed Dec 30, 2018
1 parent 0ae350c commit f79efef
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
17 changes: 8 additions & 9 deletions src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::compute_eltwise(int ur_w) {
k * jcp.ur_w + ur_w);
}

void jit_avx512_core_x8s8s32x_fwd_kernel::store_output(int ur_w,
int last_oc_block_flag)
{
void jit_avx512_core_x8s8s32x_fwd_kernel::store_output(
int ur_w, bool last_oc_block_flag) {
int nb_oc_block
= jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
Expand All @@ -134,7 +133,7 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::store_output(int ur_w,
}

for (int k = 0; k < nb_oc_block; k++) {
const bool mask_flag = last_oc_block_flag == 1 && k == nb_oc_block - 1;
const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
if (jcp.with_bias) {
int bias_offset = jcp.typesize_bia * k * oc_block;
Expand Down Expand Up @@ -169,7 +168,7 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::store_output(int ur_w,
if (maybe_eltwise(0)) compute_eltwise(ur_w);
if (p_sum_scale) { // post_op: sum
for (int k = 0; k < nb_oc_block; k++) {
const bool mask_flag = last_oc_block_flag == 1 && k == nb_oc_block - 1;
const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
for (int j = 0; j < ur_w; j++) {
int aux_output_offset
= jcp.typesize_out
Expand All @@ -188,7 +187,7 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::store_output(int ur_w,
if (maybe_eltwise(1)) compute_eltwise(ur_w);

for (int k = 0; k < nb_oc_block; k++) {
const bool mask_flag = last_oc_block_flag == 1 && k == nb_oc_block - 1;
const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
for (int j = 0; j < ur_w; j++) {
Zmm zmm = zmm_out(j, k);
if (jcp.dst_dt == data_type::u8) {
Expand Down Expand Up @@ -477,15 +476,15 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::icb_loop(

jne(common_store, T_NEAR);

store_output(ur_w, 1);
store_output(ur_w, true); // last oc block
jmp(end_store, T_NEAR);

L(common_store);
store_output(ur_w, 0);
store_output(ur_w, false);

L(end_store);
} else {
store_output(ur_w, 0);
store_output(ur_w, false);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ struct jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {

bool maybe_eltwise(int position);
void prepare_output(int ur_w);
void store_output(int ur_w, int last_oc_block_flag);
void store_output(int ur_w, bool last_oc_block_flag);
void compute_ker_dw(
int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
void compute_ker(int ur_w, int pad_l, int pad_r,
Expand Down

0 comments on commit f79efef

Please sign in to comment.