Skip to content

Commit

Permalink
cpu: fix x8s8bf16 zero-point case for gemm convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Jan 28, 2022
1 parent 9b0c031 commit 75b2d7d
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 19 deletions.
8 changes: 5 additions & 3 deletions src/cpu/gemm_x8s8s32x_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -188,10 +188,12 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr,
&& jit_gemm_convolution_utils::padding_exists(jcp);
const bool should_apply_zp_src_comp_pad_jit_pp
= should_apply_zp_src_comp_pad
&& gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
&& gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(
dst_md.data_type());
const bool should_apply_zp_src_comp_outside_pp
= should_apply_zp_src_comp_pad
&& !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
&& !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(
dst_md.data_type());

dim_t g {0}, n {0}, ohb {0}, owb {0};
dim_t start = 0, end = 0;
Expand Down
6 changes: 3 additions & 3 deletions src/cpu/gemm_x8s8s32x_convolution_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -164,9 +164,9 @@ bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d) {
return post_ops_ok(post_ops, &dst_md);
}

bool mayiuse_jit_pp_kernel() noexcept {
bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept {
#if DNNL_X64
return x64::gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel();
return x64::gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(dst_dt);
#else
return false;
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/gemm_x8s8s32x_convolution_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,7 +51,7 @@ struct pp_ker_t {

bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d);
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d);
bool mayiuse_jit_pp_kernel() noexcept;
bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept;

} // namespace gemm_x8s8s32x_convolution_utils
} // namespace cpu
Expand Down
12 changes: 6 additions & 6 deletions src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -714,14 +714,14 @@ void jit_pp_ker_t::generate() {
if (jcp_.with_eltwise) postops_injector_->prepare_table();
}

bool mayiuse_jit_pp_kernel() noexcept {
return mayiuse(avx512_core);
bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept {
const auto is_bf16_dst_dt = dst_dt == data_type::bf16;
return mayiuse(avx512_core) && !is_bf16_dst_dt;
}

pp_ker_t *jit_pp_ker_create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
const auto is_bf16_dst_dt = pd->dst_md()->data_type == data_type::bf16;
return mayiuse_jit_pp_kernel() && !is_bf16_dst_dt
return mayiuse_jit_pp_kernel(pd->dst_md()->data_type)
? new jit_pp_ker_t(pd, jcp)
: nullptr;
}
Expand All @@ -730,7 +730,7 @@ bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) {
using namespace x64::injector;
static constexpr bool sum_at_pos_0_only = true;
static constexpr bool sum_requires_scale_one = false;
return mayiuse_jit_pp_kernel()
return mayiuse_jit_pp_kernel(dst_d->data_type())
&& dnnl::impl::cpu::x64::injector::post_ops_ok(
{avx512_core, {binary, eltwise, sum}, post_ops, dst_d,
sum_at_pos_0_only, sum_requires_scale_one});
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,7 +28,7 @@ namespace gemm_x8s8s32x_convolution_utils {

cpu::gemm_x8s8s32x_convolution_utils::pp_ker_t *jit_pp_ker_create(
const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);
bool mayiuse_jit_pp_kernel() noexcept;
bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept;
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d);

} // namespace gemm_x8s8s32x_convolution_utils
Expand Down
6 changes: 3 additions & 3 deletions tests/benchdnn/inputs/conv/harness_conv_attrs_int8_asymmetric
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
--mb=2
--dir=FWD_B
--attr-oscale=per_oc:2.25
--attr-zero-points=src:common:2+dst:common:1
--attr-zero-points=src:common:-2+dst:common:1
--cfg=u8s8f32,s8s8f32 --batch=shapes_googlenet_v2
--cfg=u8s8s32 --batch=shapes_3d
--cfg=u8s8s32 --batch=shapes_gemm
--cfg=u8s8s32,s8s8bf16 --batch=shapes_gemm
--attr-post-ops=sum:1:0:s8
--cfg=u8s8u8,s8s8u8 --batch=shapes_vgg_19

Expand All @@ -17,6 +17,6 @@
--cfg=u8s8s8,s8s8s8 --batch=shapes_googlenet_v3
--cfg=u8s8s32 --batch=shapes_alexnet
--attr-zero-points=src:common:1*+dst:common:1*
--cfg=s8s8s32 --batch=shapes_alexnet --batch=shapes_3d
--cfg=s8s8s32,u8s8bf16 --batch=shapes_alexnet --batch=shapes_3d
--cfg=u8s8s32 --batch=shapes_gemm

0 comments on commit 75b2d7d

Please sign in to comment.