Skip to content

Commit

Permalink
Manually cache block attributes in PackBMatrix::pack_unpack_ (pytorch…
Browse files Browse the repository at this point in the history
…#702)

Summary:
Pull Request resolved: pytorch#702

I noticed from inspecting assembly that we were re-loading the 4 block attributes on each loop iteration. It initially surprised me that the compiler was not able to prove that the loop body doesn't mutate the block (I don't see any function calls and TBAA should mean that it can prove nothing aliases the block). On reflection, it is notable that `T` is `signed char` -- `char*` is the one exception to the strict aliasing rule (see https://en.cppreference.com/w/cpp/language/reinterpret_cast), so the compiler cannot prove that `unpack_buf` and `pack_buf` do not alias `block` and thus it has to re-load from `block` each time.

Reviewed By: jspark1105

Differential Revision: D30976702

fbshipit-source-id: 0cb47b30aae883aceface0e0a0708504e1254933
  • Loading branch information
swolchok authored and facebook-github-bot committed Sep 17, 2021
1 parent 321db4d commit dccc573
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions src/PackBMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,25 +261,37 @@ void PackBMatrix<T, accT>::pack_unpack_(
assert((block.row_start % BaseType::blockRowSize()) == 0);
assert((block.col_start % BaseType::blockColSize()) == 0);

// When T is char *, type-based alias analysis (TBAA) cannot prove
// that `unpack_buf` and `pack_buf` do not alias `block` (because
// char * is the one exception to the C++ strict aliasing rule), so the
// compiler would have to re-load these attributes from `block` on
// every loop iteration for correctness. We know better, so let's
// help the compiler out by doing the loads ourselves into
// constants.
const auto blockRowStart = block.row_start;
const auto blockRowSize = block.row_size;
const auto blockColStart = block.col_start;
const auto blockColSize = block.col_size;

BaseType::packedBlock(block);
bool tr = (trans_ == matrix_op_t::Transpose);
for (int g = 0; g < BaseType::numGroups(); ++g) {
T* pack_buf_cur = pack_buf +
g * BaseType::packedBufferSize(block.row_size, block.col_size, params);
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
g * BaseType::packedBufferSize(blockRowSize, blockColSize, params);
for (int i = blockRowStart; i < blockRowStart + blockRowSize; ++i) {
int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
(BaseType::blockRowSize() * BaseType::blockColSize()) +
(i % BaseType::blockRowSize() / row_interleave_) *
BaseType::blockColSize() * row_interleave_ +
i % row_interleave_;

int c_start_offset = (block.col_start / BaseType::blockColSize()) *
int c_start_offset = (blockColStart / BaseType::blockColSize()) *
BaseType::blockRowSize() * BaseType::blockColSize() +
(block.col_start % BaseType::blockColSize()) * row_interleave_;
(blockColStart % BaseType::blockColSize()) * row_interleave_;

int c_idx_offset = 0;
int c_blk_offset = 0;
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
for (int j = blockColStart; j < blockColStart + blockColSize; ++j) {
// int c_offset = (j / BaseType::blockColSize()) *
// BaseType::blockRowSize() * BaseType::blockColSize() +
// (j % BaseType::blockColSize()) * row_interleave_;
Expand All @@ -292,12 +304,12 @@ void PackBMatrix<T, accT>::pack_unpack_(

if (ispack) {
pack_buf_cur[r_offset + c_offset] = tr
? unpack_buf[i + (g * block.col_size + j) * ld_]
: unpack_buf[(g * block.row_size + i) * ld_ + j];
? unpack_buf[i + (g * blockColSize + j) * ld_]
: unpack_buf[(g * blockRowSize + i) * ld_ + j];
} else {
T* unpack_buf_cur = tr
? &(unpack_buf[i + (g * block.col_size + j) * ld_])
: &(unpack_buf[(g * block.row_size + i) * ld_ + j]);
? &(unpack_buf[i + (g * blockColSize + j) * ld_])
: &(unpack_buf[(g * blockRowSize + i) * ld_ + j]);
*unpack_buf_cur = pack_buf_cur[r_offset + c_offset];
}

Expand All @@ -311,8 +323,8 @@ void PackBMatrix<T, accT>::pack_unpack_(
if (ispack) {
// fill the remaining with zero.
// Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
for (int i = block.row_start + block.row_size;
i < (block.row_start + block.row_size + row_interleave_ - 1) /
for (int i = blockRowStart + blockRowSize;
i < (blockRowStart + blockRowSize + row_interleave_ - 1) /
row_interleave_ * row_interleave_;
++i) {
int r_offset =
Expand All @@ -321,7 +333,7 @@ void PackBMatrix<T, accT>::pack_unpack_(
(i % BaseType::blockRowSize() / row_interleave_) *
BaseType::blockColSize() * row_interleave_ +
i % row_interleave_;
for (int j = block.col_start; j < block.col_start + block.col_size;
for (int j = blockColStart; j < blockColStart + blockColSize;
j++) {
int c_offset = (j / BaseType::blockColSize()) *
BaseType::blockRowSize() * BaseType::blockColSize() +
Expand Down

0 comments on commit dccc573

Please sign in to comment.