Skip to content

Commit

Permalink
cpu: conv: int8: depthwise: add typedef for ic block enum
Browse files Browse the repository at this point in the history
We use an enum to pass flag tracking whether or not ic block is the last one,
and whether or not it is at the end of a spatial row. This commit adds a typedef
to prevent passing an invalid flag value.
  • Loading branch information
kwiersch committed Dec 30, 2018
1 parent cc8fc67 commit 0ae350c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
12 changes: 5 additions & 7 deletions src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::store_output(int ur_w,
}

void jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw(
int ur_w, int pad_l, int pad_r, int last_ic_block_flag) {
int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
auto input_offset = [=](int oi, int ii, int ki) {
return jcp.typesize_in
* ((ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l)
Expand Down Expand Up @@ -278,9 +278,8 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw(
}
}

void jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w,
int pad_l, int pad_r, int last_ic_block_flag, bool h_padded)
{
void jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l,
int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
if (jcp.is_depthwise)
return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag);

Expand Down Expand Up @@ -369,9 +368,8 @@ void jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w,
}
}

void jit_avx512_core_x8s8s32x_fwd_kernel::kh_loop(int ur_w,
int pad_l, int pad_r, int last_ic_block_flag)
{
void jit_avx512_core_x8s8s32x_fwd_kernel::kh_loop(
int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
Label kh_label, skip_kh_loop;
Label t_overflow_label, no_t_overflow_label,
b_overflow_label, no_b_overflow_label;
Expand Down
13 changes: 7 additions & 6 deletions src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ struct jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
ker_reg_base_idx = 28,
ker_dw_reg_base_idx = 30,
};
enum {
typedef enum {
no_last_block,
last_ic_block,
last_sp_block,
};
} ic_block_t;

/* data regs */
reg64_t reg_ptr_scales = rax;
Expand Down Expand Up @@ -165,11 +165,12 @@ struct jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
bool maybe_eltwise(int position);
void prepare_output(int ur_w);
void store_output(int ur_w, int last_oc_block_flag);
void compute_ker_dw(int ur_w, int pad_l, int pad_r, int last_ic_block_flag);
void compute_ker(int ur_w, int pad_l, int pad_r, int last_ic_block_flag,
bool h_padded = false);
void compute_ker_dw(
int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
void compute_ker(int ur_w, int pad_l, int pad_r,
ic_block_t last_ic_block_flag, bool h_padded = false);
void compute_eltwise(int ur_w);
void kh_loop(int ur_w, int pad_l, int pad_r, int last_ic_block_flag);
void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
void icb_loop(
int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
void generate();
Expand Down

0 comments on commit 0ae350c

Please sign in to comment.