returned old behavior for fp32 avx2 1x1 conv with dw conv fusing
antonvor committed Mar 26, 2021
1 parent b549701 commit fdf5370
Showing 18 changed files with 2,355 additions and 25 deletions.
9 changes: 9 additions & 0 deletions include/dnnl.h
@@ -719,6 +719,15 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
const_dnnl_post_ops_t post_ops, int index, float *scale,
dnnl_alg_kind_t *alg_kind, float *alpha, float *beta);

/** Appends a depthwise (DW) convolution post operation to the @p post_ops with
 *  the given @p weights_data and @p biases_data.
 *
 *  The kind of this post operation is #dnnl_convolution.
 */
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_conv(
dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt,
const float* weights_data, const float* biases_data);

/// Appends a depthwise post-op convolution with stride 1.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
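
For orientation, a minimal sketch of how the restored C entry point above might be called. The helper name, the 3x3/stride-1 geometry, and the dw_weights/dw_bias buffers are hypothetical; the weight layout expected by the fused kernel is not covered here.

#include <vector>
#include "dnnl.h"

// Sketch only: appends the legacy dw conv post-op to an existing post_ops handle.
// in_h/in_w describe the spatial size of the 1x1 convolution output that feeds
// the fused depthwise stage; dw_weights/dw_bias are hypothetical host buffers.
dnnl_status_t append_legacy_dw(dnnl_post_ops_t post_ops, int in_h, int in_w,
        const std::vector<float> &dw_weights, const std::vector<float> &dw_bias) {
    return dnnl_post_ops_append_dw_conv(post_ops,
            in_h, in_w, /*ker_h=*/3, /*ker_w=*/3, /*str_h=*/1, /*str_w=*/1,
            dnnl_f32, dw_weights.data(), dw_bias.data());
}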
7 changes: 7 additions & 0 deletions include/dnnl.hpp
@@ -2661,6 +2661,13 @@ struct post_ops : public handle<dnnl_post_ops_t> {
error::wrap_c_api(dnnl_post_ops_append_binarization(get(), convert_to_c(alg), weights_data, output_mask),
"could not append binarization");
}

void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt,
const float* weights_data, const float* biases_data) {
error::wrap_c_api(dnnl_post_ops_append_dw_conv(get(),
in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt, weights_data, biases_data),
"could not append dw conv");
}
};

/// @cond DO_NOT_DOCUMENT_THIS
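
The C++ wrapper above can be combined with a primitive_attr in the usual way; a rough sketch under the same assumptions (placeholder geometry and buffers):

#include <vector>
#include "dnnl.hpp"

// Sketch only: builds a primitive_attr carrying the legacy fused dw conv post-op.
// The 3x3/stride-1 geometry and the weight/bias vectors are placeholders.
dnnl::primitive_attr make_attr_with_dw(int in_h, int in_w,
        const std::vector<float> &dw_weights, const std::vector<float> &dw_bias) {
    dnnl::post_ops ops;
    ops.append_dw_conv(in_h, in_w, 3, 3, 1, 1, dnnl_f32,
            dw_weights.data(), dw_bias.data());
    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);  // the attr is then passed to the 1x1 convolution pd
    return attr;
}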
5 changes: 4 additions & 1 deletion src/common/convolution.cpp
@@ -118,7 +118,10 @@ status_t conv_desc_init(convolution_desc_t *conv_desc, prop_kind_t prop_kind,
int ker_range = 1 + (ker - 1) * (dil + 1);

if (str < 1) return invalid_arguments;
consistency = consistency && dil >= 0 && pad_l >= 0 && pad_r + str > 0
consistency = consistency
&& dil >= 0
&& pad_l >= 0
// && pad_r + str > 0 // TODO: [dmitrygo] Commented as WA to support dw conv fusing
&& (src - ker_range + pad_l + pad_r) / str + 1 == dst;
}
if (!consistency) return invalid_arguments;
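
The check that remains encodes the standard convolution output-size relation per spatial axis. A standalone illustration of the same arithmetic (not oneDNN code, names are local to this example):

// dst == (src - ker_range + pad_l + pad_r) / str + 1, where ker_range is the
// receptive field of a (possibly dilated) kernel.
bool spatial_dims_consistent(int src, int dst, int ker, int dil,
        int pad_l, int pad_r, int str) {
    const int ker_range = 1 + (ker - 1) * (dil + 1);
    if (str < 1 || dil < 0 || pad_l < 0) return false;
    // pad_r + str > 0 is intentionally not enforced here, mirroring the
    // relaxation this commit makes to support dw conv fusing.
    return (src - ker_range + pad_l + pad_r) / str + 1 == dst;
}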
3 changes: 2 additions & 1 deletion src/common/convolution_pd.hpp
@@ -281,7 +281,8 @@ struct convolution_fwd_pd_t : public convolution_pd_t {
}

int n_inputs() const override {
return 2 + with_bias() + attr_post_op_dw_inputs()
// todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported
return 2 + with_bias() /* + attr_post_op_dw_inputs() */
+ n_binary_po_inputs();
}

2 changes: 2 additions & 0 deletions src/common/memory_tracking.hpp
@@ -238,6 +238,8 @@ enum {
// even though they are not in alphabetical order
key_nested,
key_nested_multiple,
key_dw_conv_buffer,
key_dw_conv_padded_bias,
};

enum {
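
The two new scratchpad keys suggest the fused kernel stages an intermediate row buffer and a SIMD-padded copy of the dw bias. A hypothetical illustration of the padding idea only (not the actual implementation; the helper and the simd_w value are assumptions):

#include <algorithm>
#include <vector>

// What a buffer behind key_dw_conv_padded_bias could hold: the per-channel
// bias copied into storage rounded up to the vector width, so a JIT kernel
// can always load full SIMD chunks without a tail case.
std::vector<float> pad_dw_bias(const float *bias, size_t channels,
        size_t simd_w = 8 /* fp32 lanes per AVX2 register */) {
    const size_t padded = (channels + simd_w - 1) / simd_w * simd_w;
    std::vector<float> out(padded, 0.f);
    std::copy(bias, bias + channels, out.begin());
    return out;
}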
33 changes: 33 additions & 0 deletions src/common/primitive_attr.cpp
@@ -351,6 +351,28 @@ status_t post_ops_t::append_binarization(alg_kind_t alg, const float* weights_da
return success;
}

status_t post_ops_t::append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
dnnl::impl::data_type_t in_dt,
const float* weights_data,
const float* biases_data) {
if (len() == post_ops_limit) return out_of_memory;

entry_.emplace_back();
auto &e = entry_.back();
e.kind = primitive_kind::convolution;
e.depthwise_conv_old.in_h = in_h;
e.depthwise_conv_old.in_w = in_w;
e.depthwise_conv_old.ker_h = ker_h;
e.depthwise_conv_old.ker_w = ker_w;
e.depthwise_conv_old.str_h = str_h;
e.depthwise_conv_old.str_w = str_w;
e.depthwise_conv_old.in_dt = in_dt;
e.depthwise_conv_old.weights_data = weights_data;
e.depthwise_conv_old.biases_data = biases_data;

return success;
}

bool post_ops_t::defined() const {
for (int idx = 0; idx < len(); ++idx) {
auto kind = entry_[idx].kind;
@@ -758,6 +780,17 @@ status_t dnnl_post_ops_append_binarization(post_ops_t *post_ops, alg_kind_t kind
return post_ops->append_binarization(kind, weights_data, output_mask_data);
}

status_t dnnl_post_ops_append_dw_conv(post_ops_t *post_ops,
int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
dnnl::impl::data_type_t in_dt,
const float* weights_data,
const float* biases_data) {
if (post_ops == nullptr)
return invalid_arguments;

return post_ops->append_dw_conv(in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt, weights_data, biases_data);
}

status_t dnnl_primitive_attr_set_rnn_data_qparams(
primitive_attr_t *attr, const float scale, const float shift) {
if (attr == nullptr) return invalid_arguments;
72 changes: 50 additions & 22 deletions src/common/primitive_attr.hpp
@@ -422,7 +422,19 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible {
dnnl::impl::alg_kind_t alg;
const float* weights_data;
const float* output_mask_data;
} ;
};

struct depthwise_conv_old_t {
int in_h;
int in_w;
int ker_h;
int ker_w;
int str_h;
int str_w;
dnnl::impl::data_type_t in_dt;
const float* weights_data;
const float* biases_data;
};

dnnl::impl::primitive_kind_t kind
= dnnl::impl::primitive_kind::undefined;
@@ -433,6 +445,7 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible {
} sum;
eltwise_t eltwise;
depthwise_conv_t depthwise_conv;
depthwise_conv_old_t depthwise_conv_old;
binary_t binary;
depthwise_t depthwise;
quantization_t quantization;
@@ -501,22 +514,31 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible {
break;
case primitive_kind::convolution:
// Depthwise Only
ret = depthwise_conv.stride == rhs.depthwise_conv.stride
&& depthwise_conv.wei_dt
== rhs.depthwise_conv.wei_dt
&& depthwise_conv.bias_dt
== rhs.depthwise_conv.bias_dt
&& depthwise_conv.dst_dt
== rhs.depthwise_conv.dst_dt
&& depthwise_conv.count == rhs.depthwise_conv.count
&& depthwise_conv.mask == rhs.depthwise_conv.mask;
if (!ret) break;
for (int i = 0; i < depthwise_conv.count; ++i) {
ret = ret
&& depthwise_conv.scales[i]
== rhs.depthwise_conv.scales[i];
if (!ret) break;
}
// todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported
// ret = depthwise_conv.stride == rhs.depthwise_conv.stride
// && depthwise_conv.wei_dt
// == rhs.depthwise_conv.wei_dt
// && depthwise_conv.bias_dt
// == rhs.depthwise_conv.bias_dt
// && depthwise_conv.dst_dt
// == rhs.depthwise_conv.dst_dt
// && depthwise_conv.count == rhs.depthwise_conv.count
// && depthwise_conv.mask == rhs.depthwise_conv.mask;
// if (!ret) break;
// for (int i = 0; i < depthwise_conv.count; ++i) {
// ret = ret
// && depthwise_conv.scales[i]
// == rhs.depthwise_conv.scales[i];
// if (!ret) break;
// }
// break;
ret = depthwise_conv_old.in_h == rhs.depthwise_conv_old.in_h
&& depthwise_conv_old.in_w == rhs.depthwise_conv_old.in_w
&& depthwise_conv_old.ker_h == rhs.depthwise_conv_old.ker_h
&& depthwise_conv_old.ker_w == rhs.depthwise_conv_old.ker_w
&& depthwise_conv_old.str_h == rhs.depthwise_conv_old.str_h
&& depthwise_conv_old.str_w == rhs.depthwise_conv_old.str_w
&& depthwise_conv_old.in_dt == rhs.depthwise_conv_old.in_dt;
break;
case primitive_kind::binary:
ret = binary.alg == rhs.binary.alg
@@ -554,8 +576,9 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible {

private:
void clear() {
if (is_convolution() && depthwise_conv.scales)
dnnl::impl::free(depthwise_conv.scales);
// todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported
// if (is_convolution() && depthwise_conv.scales)
// dnnl::impl::free(depthwise_conv.scales);
depthwise_conv.scales = nullptr;
return;
}
@@ -566,9 +589,10 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible {
// else if(is_relu()) {} seems to be unreliable. memcpying for now.
dnnl::impl::utils::array_copy(
(char *)this, (char *)&other, sizeof(*this));
if (other.is_convolution()) {
return set_depthwise_scales(other.depthwise_conv.scales);
}
// todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported
// if (other.is_convolution()) {
// return set_depthwise_scales(other.depthwise_conv.scales);
// }
return dnnl::impl::status::success;
}
};
@@ -595,6 +619,10 @@ struct dnnl_post_ops : public dnnl::impl::c_compatible {
const void* output_scale, const void* output_shift);
dnnl::impl::status_t append_binarization(dnnl::impl::alg_kind_t alg, const float* weights_data,
const float* output_mask_data);
dnnl::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
dnnl::impl::data_type_t in_dt,
const float* weights_data,
const float* biases_data);

int find(dnnl::impl::primitive_kind_t kind, int start = 0,
int stop = -1) const {
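
A rough sketch of how an implementation could locate the legacy fused dw conv entry using the find() helper declared above. The depthwise_conv_old field names follow the structs in this diff; the post_ops_ member of primitive_attr_t is assumed from the library internals and is not shown in this change.

// Sketch only: returns true if the attr carries a legacy fused dw conv post-op.
bool has_legacy_dw_conv(const dnnl::impl::primitive_attr_t *attr) {
    const auto &po = attr->post_ops_;
    const int idx = po.find(dnnl::impl::primitive_kind::convolution);
    if (idx == -1) return false;
    const auto &dw = po.entry_[idx].depthwise_conv_old;
    // dw.ker_h, dw.ker_w, dw.str_h, dw.str_w, dw.in_dt, dw.weights_data and
    // dw.biases_data parameterize the fused depthwise stage.
    return dw.weights_data != nullptr;
}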
4 changes: 4 additions & 0 deletions src/common/verbose.cpp
@@ -289,6 +289,10 @@ void attr2str(char *str, int len, int written, const primitive_attr_t *attr) {
const char *alg_str = dnnl_alg_kind2str(qt.alg);
DPRINT(str, len, written, "%s;", alg_str);
} break;
case primitive_kind::convolution: {
const char *alg_str = "depthwise_conv_old";
DPRINT(str, len, written, "%s;", alg_str);
} break;
default: assert(!"unsupported post op primitive kind!"); break;
}
}
4 changes: 3 additions & 1 deletion src/cpu/cpu_convolution_list.cpp
@@ -36,6 +36,7 @@ using namespace dnnl::impl::cpu::aarch64;
#if DNNL_X64
#include "cpu/x64/gemm_bf16_convolution.hpp"
#include "cpu/x64/jit_avx2_1x1_convolution.hpp"
#include "cpu/x64/jit_avx2_1x1_convolution_with_dw_conv.hpp"
#include "cpu/x64/jit_avx2_convolution.hpp"
#include "cpu/x64/jit_avx512_common_1x1_convolution.hpp"
#include "cpu/x64/jit_avx512_common_convolution.hpp"
@@ -102,6 +103,7 @@ const std::map<conv_impl_key_t, std::vector<pd_create_f>> impl_list_map {
CPU_INSTANCE_X64(jit_avx2_planar_convolution_fwd_t)
CPU_INSTANCE_X64(jit_avx2_dw_convolution_fwd_t)
CPU_INSTANCE_X64(jit_avx2_fork_dw_convolution_fwd_t)
CPU_INSTANCE_X64(jit_avx2_1x1_convolution_with_dw_conv_fwd_t)
CPU_INSTANCE_X64(jit_avx2_1x1_convolution_fwd_t)
CPU_INSTANCE_X64(jit_sse41_dw_convolution_fwd_t)
CPU_INSTANCE_X64(jit_sse41_fork_dw_convolution_fwd_t)
@@ -110,10 +112,10 @@ const std::map<conv_impl_key_t, std::vector<pd_create_f>> impl_list_map {
CPU_INSTANCE_X64(jit_sse41_convolution_fwd_t)
#ifdef ENABLE_UNUSED_PRIM
CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t)
CPU_INSTANCE(ref_fused_convolution_fwd_t)
#endif
CPU_INSTANCE(gemm_convolution_fwd_t)
CPU_INSTANCE(ref_convolution_fwd_t, f32)
CPU_INSTANCE(ref_fused_convolution_fwd_t)
nullptr,
}},
{{forward, bf16, bf16, f32}, {
(remaining changed files not shown)
