Skip to content

Commit

Permalink
api: attr: introduce rounding mode and output scales
Browse files Browse the repository at this point in the history
  • Loading branch information
Fomenko, Evarist M committed Oct 31, 2017
1 parent 920d755 commit bb454cb
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 1 deletion.
71 changes: 71 additions & 0 deletions include/mkldnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,77 @@ mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(
mkldnn_primitive_attr_t attr);

/* Returns integer output rounding mode @p round_mode for a given @p attr,
* previously set by mkldnn_primitive_attr_set_int_output_round_mode. */
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(
const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode);

/* Sets output rounding mode @p round_mode for integer operations for a given
* @p attr.
*
* The default value is #mkldnn_round_nearest.
*/
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(
mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode);

/* Returns @p count, correspondence scale @p mask, and pointer to a constant
* floating point array of output @p scales for given @p attr, previously set
* by mkldnn_primitive_attr_set_output_scales.
*
* @warning
* @scales array points to the internal @p attr field, so user should not
* modify/destroy @p scales.
*
* @warning
* The lifetime of @p scales is same as @p attr it belongs to, so it is
* illegal to use the @p scales after @p attr is destroyed
*/
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(
const_mkldnn_primitive_attr_t attr, int *count, int *mask,
const float **scales);

/* Sets output @p scales for primitive operations. The number of elements @p
* count and correspondence scale @p mask are stored for future use.
*
* The @p mask argument defines correspondence between output tensor dimensions
* and the @p scales array. Set i-th bit of @p mask to 1 to use dedicated
* scaling factor for each slice of the output tensor over i-th dimension. Set
* @p mask to 0 to use common scaling factor for the whole output tensor.
*
* @note
* The dimension order is always native and does not depend on the actual
* layout used. Examples:
* - 2D dimensional data the order of dimensions is always: (n, c)
* - 4D dimensional data the order is always: (n, c, h, w)
* - 5D dimensional weights the order is always: (g, oc, ic, kh, kw)
*
* Example usage:
* @code
* int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
* float scales[oc] = { ... }; // unique output scales per output channel
* int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
*
* mkldnn_convolution_desc_t cd; // create & configure convolution op_desc
*
* mkldnn_primitive_attr_t attr;
* mkldnn_primitive_attr_create(&attr); // create default attributes
* mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
*
* mkldnn_primitive_desc_t cpd;
* mkldnn_primitive_desc_create_v2(&cpd, &cd, attr, NULL);
* @endcode
*
* @note
* There is no way to check that @p count corresponds to @p mask until an
* actual primitive descriptor is created, so it is user's responsibility
* to set proper values. The following formula must be hold:
*
* count == \prod_{d \in mask} output.dims[d]
*/
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(
mkldnn_primitive_attr_t attr, int count, int mask,
const float *scales);

/** @} */

/** @addtogroup c_api_memory Memory
Expand Down
43 changes: 43 additions & 0 deletions include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ inline mkldnn_query_t convert_to_c(query aquery) {
return static_cast<mkldnn_query_t>(aquery);
}

enum round_mode {
round_nearest = mkldnn_round_nearest,
round_down = mkldnn_round_down,
};

inline mkldnn_round_mode_t convert_to_c(round_mode mode) {
return static_cast<mkldnn_round_mode_t>(mode);
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS
template <> struct handle_traits<mkldnn_primitive_attr_t> {
static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
Expand All @@ -244,6 +253,40 @@ struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
"could not create a primitive attr");
reset(result);
}

round_mode get_int_output_round_mode() const {
mkldnn_round_mode_t result;
error::wrap_c_api(mkldnn_primitive_attr_get_int_output_round_mode(
get(), &result), "could not get int output round mode");
return round_mode(result);
}

void set_int_output_round_mode(round_mode mode) {
error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode(
get(), mkldnn::convert_to_c(mode)),
"could not set int output round mode");
}

void get_output_scales(int &mask, std::vector<float> &scales) const
{
int count, c_mask;
const float *c_scales;
error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
&count, &c_mask, &c_scales),
"could not get int output scales");
scales.resize(count);

mask = c_mask;
for (int c = 0; c < count; ++c)
scales[c] = c_scales[c];
}

void set_output_scales(int mask, const std::vector<float> &scales)
{
error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
(int)scales.size(), mask, &scales[0]),
"could not set int output scales");
}
};

/// An execution engine.
Expand Down
12 changes: 12 additions & 0 deletions include/mkldnn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ typedef enum {
mkldnn_u8 = 6,
} mkldnn_data_type_t;

/** Rounding mode */
typedef enum {
/** Round nearest */
mkldnn_round_nearest = 1,
/** Round down */
mkldnn_round_down = 2,
} mkldnn_round_mode_t;

/** Memory format specification.
*
* Intel(R) MKL-DNN uses the following notation for memory format names:
Expand Down Expand Up @@ -739,6 +747,10 @@ typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t;

/** @struct mkldnn_primitive_attr
* @brief An opaque structure for primitive descriptor attributes.
*
* Attributes may contain:
* - rounding mode for integer based primitives (like convolution, reorders)
* - output scales (to scale the result prior to storing it to the memory)
*/
struct mkldnn_primitive_attr;

Expand Down
6 changes: 6 additions & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ namespace data_type {
const data_type_t u8 = mkldnn_u8;
}

using round_mode_t = mkldnn_round_mode_t;
namespace round_mode {
const round_mode_t nearest = mkldnn_round_nearest;
const round_mode_t down = mkldnn_round_down;
}

using memory_format_t = mkldnn_memory_format_t;
namespace memory_format {
const memory_format_t undef = mkldnn_format_undef;
Expand Down
78 changes: 78 additions & 0 deletions src/common/primitive_attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,45 @@ using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::utils;

namespace mkldnn {
namespace impl {

status_t scales_t::set(int count, int mask, const float *scales) {
if (count != count_)
cleanup();

count_ = count;
mask_ = mask;

if (count_ == 1) {
scales_ = scales_buf_;
utils::array_set(scales_, scales[0], scales_buf_size);
} else {
scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
if (scales_ == nullptr)
return status::out_of_memory;

for (int c = 0; c < count_; ++c)
scales_[c] = scales[c];
}

return status::success;
}

}
}

status_t primitive_attr_t::set_round_mode(round_mode_t round_mode) {
using namespace mkldnn::impl::round_mode;

const bool ok = one_of(round_mode, nearest, down);
if (!ok)
return invalid_arguments;

round_mode_ = round_mode;
return success;
}

status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
if (attr == nullptr)
return invalid_arguments;
Expand All @@ -48,3 +87,42 @@ status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {

return success;
}

status_t mkldnn_primitive_attr_get_int_output_round_mode(
const primitive_attr_t *attr, round_mode_t *round_mode) {
if (any_null(attr, round_mode))
return invalid_arguments;

*round_mode = attr->round_mode_;

return success;
}

status_t mkldnn_primitive_attr_set_int_output_round_mode(
primitive_attr_t *attr, round_mode_t round_mode) {
if (any_null(attr))
return invalid_arguments;

return attr->set_round_mode(round_mode);
}

status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
int *count, int *mask, const float **scales) {
if (any_null(attr, count, mask, scales))
return invalid_arguments;

*count = attr->output_scales_.count_;
*mask = attr->output_scales_.mask_;
*scales = attr->output_scales_.scales_;

return success;
}

status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
int count, int mask, const float *scales) {
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
if (!ok)
return invalid_arguments;

return attr->output_scales_.set(count, mask, scales);
}
54 changes: 53 additions & 1 deletion src/common/primitive_attr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,63 @@
#include "utils.hpp"
#include "c_types_map.hpp"

namespace mkldnn {
namespace impl {

struct scales_t: public c_compatible {
scales_t(): count_(1), mask_(0), scales_(scales_buf_)
{ set(1.); }

scales_t(const scales_t &rhs): scales_t()
{ set(rhs.count_, rhs.mask_, rhs.scales_); }

~scales_t() { cleanup(); }

scales_t &operator=(const scales_t &rhs) {
if (&rhs == this)
return *this;
status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
assert(status == status::success);
(void)status;
return *this;
}

status_t set(int count, int mask, const float *scales);
status_t set(float single_scale) { return this->set(1, 0, &single_scale); }

int count_;
int mask_;
float *scales_;

private:
enum { scales_buf_size = 16 };
alignas(64) float scales_buf_[scales_buf_size];

void cleanup() {
if (scales_ != scales_buf_ && scales_ != nullptr)
impl::free(scales_);

count_ = 1;
mask_ = 0;
scales_ = scales_buf_;
}
};

}
}

struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
mkldnn_primitive_attr() {}
mkldnn_primitive_attr()
: round_mode_(mkldnn::impl::round_mode::nearest) {}

mkldnn_primitive_attr *clone() const
{ return new mkldnn_primitive_attr(*this); }

mkldnn::impl::status_t set_round_mode(
mkldnn::impl::round_mode_t round_mode);

mkldnn::impl::round_mode_t round_mode_;
mkldnn::impl::scales_t output_scales_;
};

#endif
1 change: 1 addition & 0 deletions tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}
#file(GLOB API_TEST_CASES_SRC api_tests/*.cpp)
#file(GLOB PRIM_TEST_CASES_SRC test_*.cpp)
file(GLOB PRIM_TEST_CASES_SRC
test_iface_attr.cpp
test_sum.cpp
test_reorder.cpp
test_concat.cpp
Expand Down
Loading

0 comments on commit bb454cb

Please sign in to comment.