Skip to content

Commit

Permalink
src: ocl: reorder: scale per dimension support
Browse files Browse the repository at this point in the history
  • Loading branch information
envsp committed Jul 25, 2019
1 parent eab4f28 commit 45756b5
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/ocl/jit_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,14 @@ struct jit_simple_sum_conf_t {
/* simple reorder */
// Configuration passed to the OpenCL JIT simple reorder kernel generator.
struct jit_reorder_conf_t {
    bool do_reorder, with_group, has_padding;
    // scale_quant: per-dimension output scales are applied (attribute
    // output_scales with a non-zero mask). When set, the alpha/beta sum
    // path (with_sum_ab / with_sum_a) is disabled.
    bool scale_quant, with_sum_ab, with_sum_a;
    bool use_ref_impl;
    int ndims;
    size_t nelems;
    size_t gws_d[3], lws_d[3]; // global / local work sizes
    int block[3];
    int sub_group_size;
    // Copy of output_scales mask: bit i set => scales vary along dim i.
    int scale_mask;
};

/* eltwise */
Expand Down
17 changes: 13 additions & 4 deletions src/ocl/jit_simple_reorder_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ struct jit_simple_reorder_kernel {
status_t status = status::success;

const auto &dims = output_md.padded_dims();
jrp.with_sum_ab = (pd->alpha() != 1.f || pd->beta() != 0.f);
jrp.scale_quant = pd->attr()->output_scales_.mask_ != 0;
jrp.scale_mask = jrp.scale_quant ? pd->attr()->output_scales_.mask_ : 0;
jrp.with_sum_ab = jrp.scale_quant
? false
: (pd->alpha() != 1.f || pd->beta() != 0.f);
jrp.with_sum_a = jrp.with_sum_ab && pd->beta() == 0.f;
jrp.do_reorder = jrp.with_sum_ab ? true : input_md != output_md;
jrp.do_reorder = jrp.scale_quant || jrp.with_sum_ab
? true
: input_md != output_md;
jrp.has_padding = !input_md.is_dense() || !output_md.is_dense();
jrp.ndims = input_md.ndims();
jrp.nelems = utils::array_product(dims, jrp.ndims);
Expand Down Expand Up @@ -80,7 +86,7 @@ struct jit_simple_reorder_kernel {
gOIhw2o8i8o2i))
jrp.with_group = 1;

if (jrp.has_padding)
if (jrp.has_padding || jrp.scale_quant)
return status;

const bool type_s8_u8
Expand Down Expand Up @@ -145,7 +151,10 @@ struct jit_simple_reorder_kernel {
const memory_desc_wrapper &output_md) {

kernel_ctx.define_int("NDIMS", jrp.ndims);
if (jrp.with_sum_a)
if (jrp.scale_quant) {
kernel_ctx.define_int("SCALE_QUANT", 1);
kernel_ctx.define_int("SCALE_MASK", jrp.scale_mask);
} else if (jrp.with_sum_a)
kernel_ctx.define_int("WITH_SUM_A", 1);
else if (jrp.with_sum_ab)
kernel_ctx.define_int("WITH_SUM_AB", 1);
Expand Down
15 changes: 14 additions & 1 deletion src/ocl/ocl_cross_engine_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,32 @@ status_t ocl_cross_engine_reorder_t::execute(const exec_ctx_t &ctx) const {
float beta = pd()->beta();
const bool do_reorder = jrp.do_reorder;

status_t status = status::success;
auto ocl_reorder = [&](const memory_storage_t &in_storage,
const memory_storage_t &out_storage) {
if (scales) {
void *tmp_ptr = nullptr;
status = scales->map_data(&tmp_ptr);
if (status != status::success)
return status;
memcpy(tmp_ptr, pd()->attr()->output_scales_.scales_,
pd()->attr()->output_scales_.count_ * sizeof(float));
status = scales->unmap_data(tmp_ptr);
if (status != status::success)
return status;
}

compute::kernel_arg_list_t arg_list;
arg_list.set(0, in_storage);
arg_list.set(1, out_storage);
arg_list.set(2, alpha);
arg_list.set(3, beta);
arg_list.set(4, scales ? *scales : memory_storage_t::empty_storage());

auto nd_range = compute::nd_range_t(jrp.gws_d, jrp.lws_d);
return compute_stream->parallel_for(nd_range, kernel_, arg_list);
};

status_t status = status::success;
if (in_e_kind == engine_kind::gpu && out_e_kind == engine_kind::cpu) {
if (do_reorder) {
status = ocl_reorder(input, *temp_buf);
Expand Down
9 changes: 9 additions & 0 deletions src/ocl/ocl_cross_engine_reorder_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ struct ocl_cross_engine_reorder_t : public primitive_t {
if (!temp_buf)
return status::runtime_error;
}
if (pd()->jrp_.scale_quant) {
size_t size = pd()->attr()->output_scales_.count_ * sizeof(float);
memory_storage_t *scales_ptr;
engine()->create_memory_storage(&scales_ptr, size);
scales.reset(scales_ptr);
if (!scales)
return status::runtime_error;
}

return status::success;
}
Expand All @@ -140,6 +148,7 @@ struct ocl_cross_engine_reorder_t : public primitive_t {
compute::kernel_t kernel_;
jit_simple_reorder_kernel *ker_;
std::unique_ptr<memory_storage_t> temp_buf;
std::unique_ptr<memory_storage_t> scales;
};

} // namespace ocl
Expand Down
53 changes: 51 additions & 2 deletions src/ocl/simple_reorder.cl
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,22 @@ ushort8 float_to_bfloat8(float8 b) {
#define CONVERT_IN_TO_OUT8(x) CONVERT_F32_TO_OUT8(x)
#endif

#if WITH_SUM_A
#if SCALE_QUANT

#define REORDER(_out, _in, _a, _b) \
do { \
const float _x = CONVERT_IN_TO_F32(_in); \
const float _s = _a * _x + _b; \
_out = CONVERT_F32_TO_OUT(_s); \
} while (0)
#define REORDER8(_out, _in, _a, _b) \
do { \
const float8 _x = CONVERT_IN_TO_F32_8(_in); \
const float8 _s = _a * _x + _b; \
_out = CONVERT_F32_TO_OUT8(_s); \
} while (0)

#elif WITH_SUM_A

#define REORDER(_out, _in, _a, _b) \
do { \
Expand Down Expand Up @@ -492,12 +507,37 @@ ushort8 float_to_bfloat8(float8 b) {

#endif // WITH_SUM_AB

#if SCALE_QUANT

// MASK_D(d) is 1 when the output-scales mask selects dimension d,
// i.e. the scale factor varies along that dimension.
#define MASK_D(_d) ((SCALE_MASK >> _d) & 1)

// Logical extent of the scales array along each dimension: the full
// source dimension when masked, otherwise collapsed to 1 (broadcast).
#define SCALE_D0 (MASK_D(0) ? SRC_D0 : 1)
#define SCALE_D1 (MASK_D(1) ? SRC_D1 : 1)
#define SCALE_D2 (MASK_D(2) ? SRC_D2 : 1)
#define SCALE_D3 (MASK_D(3) ? SRC_D3 : 1)
#define SCALE_D4 (MASK_D(4) ? SRC_D4 : 1)
#define SCALE_D5 (MASK_D(5) ? SRC_D5 : 1)

// Row-major strides of the dense scales array built from SCALE_D*.
#define SCALE_S0 (SCALE_D1 * SCALE_D2 * SCALE_D3 * SCALE_D4 * SCALE_D5)
#define SCALE_S1 (SCALE_D2 * SCALE_D3 * SCALE_D4 * SCALE_D5)
#define SCALE_S2 (SCALE_D3 * SCALE_D4 * SCALE_D5)
#define SCALE_S3 (SCALE_D4 * SCALE_D5)
#define SCALE_S4 (SCALE_D5)
#define SCALE_S5 (1)

// Linear offset into the scales buffer for element (x0..x5): each term
// is zeroed via MASK_D when the corresponding dimension is broadcast.
#define SCALE_OFF(x0, x1, x2, x3, x4, x5) \
((x0)*SCALE_S0 * MASK_D(0) + (x1)*SCALE_S1 * MASK_D(1) \
+ (x2)*SCALE_S2 * MASK_D(2) + (x3)*SCALE_S3 * MASK_D(3) \
+ (x4)*SCALE_S4 * MASK_D(4) + (x5)*SCALE_S5 * MASK_D(5))

#endif // SCALE_QUANT

#if SUB_GROUP_SIZE != 1
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
#endif
__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void
any2any_kernel(__global DT_IN *input, __global DT_OUT *output, float alpha,
float beta) {
float beta, __global float *scales) {

input += SRC_OFFSET_PAD;
output += DST_OFFSET_PAD;
Expand All @@ -520,6 +560,9 @@ any2any_kernel(__global DT_IN *input, __global DT_OUT *output, float alpha,
{
const int in_off = IN_OFF(d0, d1, d2, 0, 0, 0);
const int out_off = OUT_OFF(d0, d1, d2, 0, 0, 0);
#if SCALE_QUANT
alpha = scales[SCALE_OFF(d0, d1, d2, 0, 0, 0)];
#endif
REORDER(output[out_off], input[in_off], alpha, beta);
}
#elif NDIMS <= 5
Expand All @@ -540,6 +583,9 @@ any2any_kernel(__global DT_IN *input, __global DT_OUT *output, float alpha,
for (int d4 = 0; d4 < SRC_D4; ++d4) {
const int in_off = IN_OFF(d0, d1, d2, d3, d4, 0);
const int out_off = OUT_OFF(d0, d1, d2, d3, d4, 0);
#if SCALE_QUANT
alpha = scales[SCALE_OFF(d0, d1, d2, d3, d4, 0)];
#endif
REORDER(output[out_off], input[in_off], alpha, beta);
}
#if PAD_FILL_ZERO
Expand Down Expand Up @@ -569,6 +615,9 @@ any2any_kernel(__global DT_IN *input, __global DT_OUT *output, float alpha,
for (int d5 = 0; d5 < DST_D5; ++d5) {
const int in_off = IN_OFF(d0, d1, d2, d3, d4, d5);
const int out_off = OUT_OFF(d0, d1, d2, d3, d4, d5);
#if SCALE_QUANT
alpha = scales[SCALE_OFF(d0, d1, d2, d3, d4, d5)];
#endif
REORDER(output[out_off], input[in_off], alpha, beta);
}
#if PAD_FILL_ZERO
Expand Down
66 changes: 66 additions & 0 deletions tests/benchdnn/inputs/reorder/test_reorder_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,72 @@
--dtag=abcd,aBcd16b,ABcd16a16b
32x32x2x2

### test scale
# axis type
--reset
--sdt=u8
--ddt=u8

--attr=oscale=per_dim_0:0.
--stag=abcd,cdba
--dtag=abcd,cdba
3x5x7x11

--attr=oscale=per_dim_1:0.
--stag=abcd,cdba
--dtag=abcd,cdba
3x5x7x11

--attr=oscale=per_dim_01:0.
--stag=abcd,cdba
--dtag=abcd,cdba
3x5x7x11

# data types
--reset
--sdt=f32,s32,s8,u8,f16,bf16
--ddt=f32,s32,s8,u8,f16,bf16

--attr=oscale=per_dim_1:0.5
--stag=abcd,cdba
--dtag=abcd,cdba
3x5x7x11

# layouts
--reset
--sdt=u8
--ddt=u8

--attr=oscale=per_dim_1:0.5
--stag=a
--dtag=a
128

--attr=oscale=per_dim_1:0.5
--stag=ab,ba
--dtag=ab,ba
7x11

--attr=oscale=per_dim_1:0.5
--stag=abc,bac
--dtag=abc,bac
5x7x11

--attr=oscale=per_dim_1:0.5
--stag=abcd,ABcd16a16b
--dtag=abcd,ABcd16a16b
32x32x2x2

--attr=oscale=per_dim_1:0.5
--stag=abcde,ABcde16a16b
--dtag=abcde,ABcde16a16b
32x32x2x2x2

--attr=oscale=per_dim_1:0.5
--stag=abcdef,aBCdef16b16c
--dtag=abcdef,aBCdef16b16c
2x32x32x2x2x2

### test blocking
--reset
--sdt=f32
Expand Down

0 comments on commit 45756b5

Please sign in to comment.