Skip to content

Commit

Permalink
cpu: x64: reorder: fix the computation of blk_chunk_idx
Browse files Browse the repository at this point in the history
  • Loading branch information
xuxinzen authored and vpirogov committed Jun 25, 2021
1 parent c3ab4da commit 95df54c
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/cpu/x64/jit_uni_reorder_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,28 @@ static status_t compute_blk_and_tail(
return status::success;
}

// Searches the nodes of `p` for the one that matches the blocked dimension
// `blk_idx` and reports its position through `chunk_idx`.
//
// When the problem has neither an input tail nor an output tail
// (`p.ip_tail == 0 && p.op_tail == 0`) there is nothing to look up: the
// function returns status::success and leaves `chunk_idx` untouched, so the
// caller's initial value is kept.
//
// Otherwise nodes in the range [blk_idx, omd.ndims()) are scanned in order and
// the first one whose strides satisfy both
//   - node.os == output stride of dimension `blk_idx`
//   - node.is == input stride of dimension `blk_idx` multiplied by
//                obd.inner_blks[obd.inner_idxs[blk_idx]]
// is selected.
//
// @param p          reorder problem whose `nodes`, `ip_tail` and `op_tail`
//                   are read
// @param imd_       input memory descriptor (its blocking strides are read)
// @param omd_       output memory descriptor (its blocking strides, inner
//                   blocks and number of dimensions are read)
// @param blk_idx    index of the dimension to match; also the first node
//                   position that is inspected
// @param chunk_idx  [out] receives the index of the matching node; written
//                   only when a match is found
// @return status::success when no tail is present or a matching node was
//         found, status::invalid_arguments when no node matches.
static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_,
        const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) {
    const memory_desc_wrapper src_d(imd_);
    const memory_desc_wrapper dst_d(omd_);
    const auto &src_blk = src_d.blocking_desc();
    const auto &dst_blk = dst_d.blocking_desc();

    // Without tails the caller's default chunk index stays valid.
    const bool has_tail = p.ip_tail != 0 || p.op_tail != 0;
    if (!has_tail) return status::success;

    // Strides the wanted node is expected to carry.
    const ptrdiff_t wanted_os = dst_blk.strides[blk_idx];
    const ptrdiff_t wanted_is = src_blk.strides[blk_idx]
            * dst_blk.inner_blks[dst_blk.inner_idxs[blk_idx]];

    const int end_idx = dst_d.ndims();
    for (int node_idx = blk_idx; node_idx < end_idx; ++node_idx) {
        const auto &node = p.nodes[node_idx];
        if (node.os != wanted_os) continue;
        if (node.is != wanted_is) continue;

        // First match wins.
        chunk_idx = node_idx;
        return status::success;
    }

    // No node carries the expected pair of strides.
    return status::invalid_arguments;
}

status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) {
const auto md = memory_desc_wrapper(md_);
Expand Down Expand Up @@ -203,7 +225,6 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,

int i_pos = 0; /* state for input -- current dimension */
int o_pos = 0; /* state for output -- current dimension */
int blk_chunk_idx = 0;

while (i_pos < ild.ndims && o_pos < old.ndims) {
assert(ild.id[i_pos] == old.id[o_pos]);
Expand All @@ -226,7 +247,6 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
p.nodes[ndims].is = ild.strides[i_pos];
p.nodes[ndims].os = old.strides[o_pos] * factor;
p.nodes[ndims].ss = ss[o_pos] * factor;
blk_chunk_idx = op_padding[o_pos] > 0 ? ndims : blk_chunk_idx;
++ndims;
++i_pos;
old.dims[o_pos] = factor;
Expand All @@ -237,12 +257,14 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
p.nodes[ndims].is = ild.strides[i_pos] * factor;
p.nodes[ndims].os = old.strides[o_pos];
p.nodes[ndims].ss = ss[o_pos];
blk_chunk_idx = ip_padding[i_pos] > 0 ? ndims : blk_chunk_idx;
++ndims;
++o_pos;
ild.dims[i_pos] = factor;
}
}
int blk_chunk_idx = ndims;
CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx));

p.ndims = ndims;
p.full_ndims = ndims;
p.blk_chunk_idx = blk_chunk_idx;
Expand Down

0 comments on commit 95df54c

Please sign in to comment.