Skip to content

Commit

Permalink
api: support auto dispatching of conv algorithm
Browse files Browse the repository at this point in the history
- breaks ABI w.r.t. the previous commit.
  • Loading branch information
shelleygoel committed Dec 28, 2018
1 parent a6e933a commit df14434
Show file tree
Hide file tree
Showing 33 changed files with 382 additions and 84 deletions.
1 change: 1 addition & 0 deletions include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {

enum algorithm {
algorithm_undef = mkldnn_alg_kind_undef,
convolution_auto = mkldnn_convolution_auto,
convolution_direct = mkldnn_convolution_direct,
convolution_winograd = mkldnn_convolution_winograd,
deconvolution_direct = mkldnn_deconvolution_direct,
Expand Down
53 changes: 28 additions & 25 deletions include/mkldnn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,50 +440,52 @@ typedef enum {
typedef enum {
mkldnn_alg_kind_undef,
/** Direct convolution */
mkldnn_convolution_direct = 1,
mkldnn_convolution_direct = 0x1,
/** Winograd convolution */
mkldnn_convolution_winograd = 2,
mkldnn_convolution_winograd = 0x2,
/** Convolution algorithm (either direct or Winograd) is chosen just in time */
mkldnn_convolution_auto = 0x3,
/** Direct deconvolution */
mkldnn_deconvolution_direct = 0xa,
/** Winograd deconvolution */
mkldnn_deconvolution_winograd = 0xb,
/** Eltwise: ReLU */
mkldnn_eltwise_relu = 8,
mkldnn_eltwise_relu = 0x1f,
/** Eltwise: hyperbolic tangent non-linearity (tanh) */
mkldnn_eltwise_tanh = 9,
mkldnn_eltwise_tanh = 0x2f,
/** Eltwise: parametric exponential linear unit (elu) */
mkldnn_eltwise_elu = 10,
mkldnn_eltwise_elu = 0x3f,
/** Eltwise: square */
mkldnn_eltwise_square = 11,
mkldnn_eltwise_square = 0x4f,
/** Eltwise: abs */
mkldnn_eltwise_abs = 12,
mkldnn_eltwise_abs = 0x5f,
/** Eltwise: square root */
mkldnn_eltwise_sqrt = 13,
mkldnn_eltwise_sqrt = 0x6f,
/** Eltwise: linear */
mkldnn_eltwise_linear = 14,
mkldnn_eltwise_linear = 0x7f,
/** Eltwise: bounded_relu */
mkldnn_eltwise_bounded_relu = 15,
mkldnn_eltwise_bounded_relu = 0x8f,
/** Eltwise: soft_relu */
mkldnn_eltwise_soft_relu = 16,
mkldnn_eltwise_soft_relu = 0x9f,
/** Eltwise: logistic */
mkldnn_eltwise_logistic = 17,
mkldnn_eltwise_logistic = 0xaf,
/** Max pooling */
mkldnn_pooling_max = 34,
mkldnn_pooling_max = 0x1ff,
/** Average pooling include padding */
mkldnn_pooling_avg_include_padding = 40,
mkldnn_pooling_avg_include_padding = 0x2ff,
/** Average pooling exclude padding */
mkldnn_pooling_avg_exclude_padding = 41,
mkldnn_pooling_avg_exclude_padding = 0x3ff,
mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
/** Local response normalization (LRN) across multiple channels */
mkldnn_lrn_across_channels = 65,
mkldnn_lrn_across_channels = 0xaff,
/** LRN within a single channel */
mkldnn_lrn_within_channel = 66,
/** Direct deconvolution */
mkldnn_deconvolution_direct = 71,
/** Winograd deconvolution */
mkldnn_deconvolution_winograd = 72,
mkldnn_lrn_within_channel = 0xbff,
/** RNN cell */
mkldnn_vanilla_rnn = 80,
mkldnn_vanilla_rnn = 0x1fff,
/** LSTM cell */
mkldnn_vanilla_lstm = 81,
mkldnn_vanilla_lstm = 0x2fff,
/** GRU cell */
mkldnn_vanilla_gru = 82,
mkldnn_vanilla_gru = 0x3fff,
/** GRU cell with linear before reset
*
* Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
Expand All @@ -492,7 +494,7 @@ typedef enum {
* Primitive expects 4 biases on input:
* \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
* */
mkldnn_gru_linear_before_reset = 83,
mkldnn_gru_linear_before_reset = 0x4fff,
} mkldnn_alg_kind_t;

/** Flags for batch-normalization primitive. */
Expand Down Expand Up @@ -1193,3 +1195,4 @@ typedef const struct mkldnn_stream *const_mkldnn_stream_t;


#endif

1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ namespace prop_kind {
using alg_kind_t = mkldnn_alg_kind_t;
namespace alg_kind {
const alg_kind_t undef = mkldnn_alg_kind_undef;
const alg_kind_t convolution_auto = mkldnn_convolution_auto;
const alg_kind_t convolution_direct = mkldnn_convolution_direct;
const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
Expand Down
2 changes: 1 addition & 1 deletion src/common/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ status_t conv_desc_init(convolution_desc_t *conv_desc,
bool args_ok = true
&& !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
padding_l)
&& one_of(alg_kind, convolution_direct, convolution_winograd)
&& one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
&& one_of(padding_kind, padding_kind::padding_zero);
if (!args_ok) return invalid_arguments;

Expand Down
19 changes: 19 additions & 0 deletions src/common/convolution_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,19 @@ struct convolution_fwd_pd_t: public primitive_desc_t {

inline int ndims() const { return desc_.src_desc.ndims; }

/** Overwrites the algorithm kind stored in the convolution descriptor.
 *  Implementations use this to resolve alg_kind::convolution_auto to the
 *  concrete algorithm they actually run (see cpu_convolution_pd).
 *  Returns status::invalid_arguments for alg_kind::undef. */
virtual status_t set_alg_kind(alg_kind_t alg) {
    if (alg != alg_kind::undef) {
        desc_.alg_kind = alg;
        return status::success;
    }
    return status::invalid_arguments;
}

bool has_zero_dim_memory() const {
return false
|| memory_desc_wrapper(desc_.src_desc).has_zero_dim()
|| memory_desc_wrapper(desc_.dst_desc).has_zero_dim();
}


protected:
convolution_desc_t desc_;
const convolution_fwd_pd_t *hint_fwd_pd_;
Expand Down Expand Up @@ -249,6 +256,12 @@ struct convolution_bwd_data_pd_t: public primitive_desc_t {
inline int ndims() const { return desc_.diff_src_desc.ndims; }
virtual bool support_bias() const { return false; }

/** Overwrites the algorithm kind stored in the convolution descriptor.
 *  Lets an implementation replace alg_kind::convolution_auto with the
 *  algorithm it selected; alg_kind::undef is rejected. */
virtual status_t set_alg_kind(alg_kind_t alg) {
    const bool ok = alg != alg_kind::undef;
    if (!ok) return status::invalid_arguments;
    desc_.alg_kind = alg;
    return status::success;
}

bool has_zero_dim_memory() const {
return false
|| memory_desc_wrapper(desc_.diff_src_desc).has_zero_dim()
Expand Down Expand Up @@ -363,6 +376,12 @@ struct convolution_bwd_weights_pd_t: public primitive_desc_t {

inline int ndims() const { return desc_.src_desc.ndims; }

/** Overwrites the algorithm kind stored in the convolution descriptor.
 *  Used when resolving alg_kind::convolution_auto to a concrete algorithm;
 *  passing alg_kind::undef is an error. */
virtual status_t set_alg_kind(alg_kind_t alg) {
    const bool valid = (alg != alg_kind::undef);
    if (valid) desc_.alg_kind = alg;
    return valid ? status::success : status::invalid_arguments;
}

bool has_zero_dim_memory() const {
return false
|| memory_desc_wrapper(desc_.src_desc).has_zero_dim()
Expand Down
1 change: 1 addition & 0 deletions src/common/mkldnn_debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {

const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
if (v == mkldnn_alg_kind_undef) return "undef";
if (v == mkldnn_convolution_auto) return "convolution_auto";
if (v == mkldnn_convolution_direct) return "convolution_direct";
if (v == mkldnn_convolution_winograd) return "convolution_winograd";
if (v == mkldnn_eltwise_relu) return "eltwise_relu";
Expand Down
6 changes: 6 additions & 0 deletions src/cpu/cpu_convolution_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
CHECK(weights_pd_.set_format(wei_format()));
if (bias_pd_.desc()->format == any)
CHECK(bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down Expand Up @@ -157,6 +159,8 @@ struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
CHECK(weights_pd_.set_format(wei_format()));
if (bias_pd_.desc()->format == any)
CHECK(bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down Expand Up @@ -221,6 +225,8 @@ struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t {
CHECK(diff_weights_pd_.set_format(wei_format()));
if (diff_bias_pd_.desc()->format == any)
CHECK(diff_bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down
16 changes: 13 additions & 3 deletions src/cpu/gemm_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ struct gemm_convolution_fwd_t: public cpu_primitive_t {
&& this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, forward_training,
forward_inference)
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind,
alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
Expand Down Expand Up @@ -97,6 +99,8 @@ struct gemm_convolution_fwd_t: public cpu_primitive_t {
CHECK(this->weights_pd_.set_format(wei_format()));
if (this->bias_pd_.desc()->format == any)
CHECK(this->bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}

Expand Down Expand Up @@ -168,7 +172,8 @@ struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
bool ok = true
&& this->set_default_params() == status::success
&& this->desc()->prop_kind == backward_data
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->diff_src_desc.data_type,
Expand Down Expand Up @@ -210,6 +215,8 @@ struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
CHECK(this->diff_dst_pd_.set_format(src_format()));
if (this->weights_pd_.desc()->format == any)
CHECK(this->weights_pd_.set_format(wei_format()));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down Expand Up @@ -257,7 +264,8 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
bool ok = true
&& this->set_default_params() == status::success
&& this->desc()->prop_kind == backward_weights
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
Expand Down Expand Up @@ -303,6 +311,8 @@ struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
CHECK(this->diff_weights_pd_.set_format(wei_format()));
if (this->diff_bias_pd_.desc()->format == any)
CHECK(this->diff_bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down
13 changes: 9 additions & 4 deletions src/cpu/gemm_x8s8s32x_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
&& utils::one_of(this->desc()->prop_kind,
prop_kind::forward_training,
prop_kind::forward_inference)
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind,
alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& this->desc()->src_desc.data_type == src_type
&& this->desc()->dst_desc.data_type == dst_type
Expand Down Expand Up @@ -96,7 +98,8 @@ struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
: (is_sign_input ? hwio_s8s8 : hwio)));
if (this->bias_pd_.desc()->format == any)
CHECK(this->bias_pd_.set_format(x));

if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}

Expand Down Expand Up @@ -214,7 +217,8 @@ struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
bool ok = true
&& this->set_default_params() == status::success
&& this->desc()->prop_kind == prop_kind::backward_data
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& this->desc()->diff_src_desc.data_type == dst_type
&& this->desc()->diff_dst_desc.data_type == u8
Expand Down Expand Up @@ -253,7 +257,8 @@ struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
this->with_groups() ? hwigo : hwio));
if (bias_pd_.desc()->format == any)
CHECK(bias_pd_.set_format(x));

if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down
16 changes: 13 additions & 3 deletions src/cpu/jit_avx2_1x1_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t {
&& this->set_default_params() == status::success
&& utils::one_of(this->desc()->prop_kind, forward_training,
forward_inference)
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind,
alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
Expand Down Expand Up @@ -99,6 +101,8 @@ struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t {
: utils::pick(this->ndims() - 3, OIw8i8o, OIhw8i8o)));
if (this->bias_pd_.desc()->format == any)
CHECK(this->bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down Expand Up @@ -154,7 +158,8 @@ struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t {
bool ok = true
&& this->set_default_params() == status::success
&& this->desc()->prop_kind == backward_data
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->diff_src_desc.data_type,
Expand Down Expand Up @@ -197,6 +202,8 @@ struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t {
CHECK(this->weights_pd_.set_format(this->with_groups()
? utils::pick(this->ndims() - 3, gOIw8o8i, gOIhw8o8i)
: utils::pick(this->ndims() - 3, OIw8o8i, OIhw8o8i)));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}
};
Expand Down Expand Up @@ -257,7 +264,8 @@ struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t {
bool ok = true
&& this->set_default_params() == status::success
&& this->desc()->prop_kind == backward_weights
&& this->desc()->alg_kind == alg_kind::convolution_direct
&& utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
alg_kind::convolution_direct)
&& !this->has_zero_dim_memory()
&& utils::everyone_is(data_type::f32,
this->desc()->src_desc.data_type,
Expand Down Expand Up @@ -315,6 +323,8 @@ struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t {
: utils::pick(this->ndims() - 3, OIw8i8o, OIhw8i8o)));
if (this->diff_bias_pd_.desc()->format == any)
CHECK(this->diff_bias_pd_.set_format(x));
if (this->desc()->alg_kind == alg_kind::convolution_auto)
CHECK(this->set_alg_kind(alg_kind::convolution_direct));
return status::success;
}

Expand Down
Loading

0 comments on commit df14434

Please sign in to comment.