Skip to content

Commit

Permalink
common: add spatial broadcast strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun committed May 1, 2024
1 parent f3eafbe commit 51d056e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/common/broadcast_strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
broadcasting_strategy_t::per_mb,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w, broadcasting_strategy_t::per_w,
broadcasting_strategy_t::batch,
broadcasting_strategy_t::batch, broadcasting_strategy_t::spatial,
broadcasting_strategy_t::no_broadcast};

return get_rhs_arg_broadcasting_strategy(
Expand Down Expand Up @@ -147,6 +147,24 @@ bool is_batch_bcast(const std::bitset<DNNL_MAX_NDIMS> mask,
return batch_bcast;
}

// Checks if mask corresponds to broadcast per mb and oc dimensions
// Returns true if mask (5D) is equal to [0, 0, 1, 1, 1]
bool is_spatial_bcast(const std::bitset<DNNL_MAX_NDIMS> mask,
const memory_desc_wrapper &dst_d) {
if (!dst_d.is_plain()) return false; // blocked format not supported

bool spatial_bcast = !mask.test(0) && !mask.test(1);
if (!spatial_bcast) return false;

const size_t ndims = dst_d.ndims();
assert(ndims > 0);

for (size_t d = 2; d < ndims; ++d)
spatial_bcast = spatial_bcast && mask.test(d);

return spatial_bcast;
}

bool bcast_strategy_enabled(const bcast_set_t &supported_strategy_set,
const broadcasting_strategy_t &bcast) {
return supported_strategy_set.find(bcast) != supported_strategy_set.cend();
Expand Down Expand Up @@ -233,6 +251,9 @@ broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
else if (is_batch_bcast(mask, dst_d, output_dims)
&& is_enabled(broadcasting_strategy_t::batch))
bcast = broadcasting_strategy_t::batch;
else if (is_spatial_bcast(mask, dst_d)
&& is_enabled(broadcasting_strategy_t::spatial))
bcast = broadcasting_strategy_t::spatial;
else if (is_enabled(broadcasting_strategy_t::shared_axes))
bcast = broadcasting_strategy_t::shared_axes;

Expand Down
1 change: 1 addition & 0 deletions src/common/broadcast_strategy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class broadcasting_strategy_t {
per_w, // [1, 1, 1, 1, w] // Broadcast per width
shared_axes, // [n, 1, d, h, 1] // General case broadcast (any combination)
batch, // [1, c, d, h, w] // Broadcast only batch
spatial, // [n, c, 1, 1, 1] // Broadcast spatial dimensions
no_broadcast, // [n, c, d, h, w]
unsupported
};
Expand Down

0 comments on commit 51d056e

Please sign in to comment.