Skip to content

Commit

Permalink
fix group convention in B packing (pytorch#26)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#26

Set the convention that group is the leading (slowest-moving) dimension of B.

Reviewed By: dskhudia

Differential Revision: D13176477

fbshipit-source-id: 64d5f168434e7fa0f90b46b0a8559569804c844b
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Nov 27, 2018
1 parent ea47a69 commit db52c82
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 19 deletions.
10 changes: 10 additions & 0 deletions include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,16 @@ class PackBMatrix final : public PackMatrix<PackBMatrix<T, accT>, T, accT> {

PackBMatrix() = delete; // no default constructor

/**
* @param groups if > 1 and trans == NoTranspose, smat is nRow x nCol with
* groups vertically concatenated: each group is
* (nRow / groups) x nCol .
* If > 1 and trans == Transpose, smat is (nCol * groups) x
* (nRow / groups) with groups horizontally concatenated:
* each group is nCol x (nRow / groups) . Each group is
* transposed and vertically concatenated to match the
* NoTranspose case.
*/
PackBMatrix(
matrix_op_t trans,
std::int32_t nRow,
Expand Down
2 changes: 1 addition & 1 deletion src/PackBMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) {
g * this->packedBufferSize(block.row_size, block.col_size);
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
T val = tr ? smat_[g * block.row_size + i + ld_ * j]
T val = tr ? smat_[i + (g * block.col_size + j) * ld_]
: smat_[(g * block.row_size + i) * ld_ + j];
out[addr(i, j)] = tconv(val, out[addr(i, j)]);
}
Expand Down
18 changes: 9 additions & 9 deletions test/PackedRequantizeAcc16Test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ TEST_P(fbgemmu8s8acc16test, Test) {
n,
Bint8.data() + g * k_per_group * n,
n,
Bint8_temp.data() + g * k_per_group,
groups * k_per_group);
Bint8_temp.data() + g * k_per_group * n,
k_per_group);
}
Bint8 = Bint8_temp;
}
Expand Down Expand Up @@ -197,7 +197,7 @@ TEST_P(fbgemmu8s8acc16test, Test) {
k,
n_adjusted,
Bint8.data(),
(btrans == matrix_op_t::Transpose) ? k : n,
(btrans == matrix_op_t::Transpose) ? k_per_group : n,
nullptr,
groups);

Expand Down Expand Up @@ -376,8 +376,8 @@ TEST_P(fbgemmu8s8acc16test, SpMDMTest) {
n,
Bint8.data() + g * k_per_group * n,
n,
Bint8_temp.data() + g * k_per_group,
groups * k_per_group);
Bint8_temp.data() + g * k_per_group * n,
k_per_group);
}
Bint8 = Bint8_temp;
}
Expand Down Expand Up @@ -442,7 +442,7 @@ TEST_P(fbgemmu8s8acc16test, SpMDMTest) {
k,
n_adjusted,
Bint8.data(),
(btrans == matrix_op_t::Transpose) ? k : n,
(btrans == matrix_op_t::Transpose) ? k_per_group : n,
nullptr,
groups);

Expand Down Expand Up @@ -566,8 +566,8 @@ TEST_P(fbgemmu8s8acc16test, NoRequantizeTest) {
n,
Bint8.data() + g * k_per_group * n,
n,
Bint8_temp.data() + g * k_per_group,
groups * k_per_group);
Bint8_temp.data() + g * k_per_group * n,
k_per_group);
}
Bint8 = Bint8_temp;
}
Expand Down Expand Up @@ -628,7 +628,7 @@ TEST_P(fbgemmu8s8acc16test, NoRequantizeTest) {
k,
n_adjusted,
Bint8.data(),
(btrans == matrix_op_t::Transpose) ? k : n,
(btrans == matrix_op_t::Transpose) ? k_per_group : n,
nullptr,
groups);

Expand Down
18 changes: 9 additions & 9 deletions test/PackedRequantizeTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ TEST_P(fbgemmu8s8acc32test, Test) {
n,
Bint8.data() + g * k_per_group * n,
n,
Bint8_temp.data() + g * k_per_group,
groups * k_per_group);
Bint8_temp.data() + g * k_per_group * n,
k_per_group);
}
Bint8 = Bint8_temp;
}
Expand Down Expand Up @@ -217,7 +217,7 @@ TEST_P(fbgemmu8s8acc32test, Test) {
k,
n_adjusted,
Bint8.data(),
(btrans == matrix_op_t::Transpose) ? k : n,
(btrans == matrix_op_t::Transpose) ? k_per_group : n,
nullptr,
groups);

Expand Down Expand Up @@ -372,8 +372,8 @@ TEST_P(fbgemmu8s8acc32test, TestFloatInputOutput) {
n,
Bint8.data() + g * k_per_group * n,
n,
Bint8_temp.data() + g * k_per_group,
groups * k_per_group);
Bint8_temp.data() + g * k_per_group * n,
k_per_group);
}
Bint8 = Bint8_temp;
}
Expand All @@ -396,7 +396,7 @@ TEST_P(fbgemmu8s8acc32test, TestFloatInputOutput) {
k,
n_adjusted,
Bint8.data(),
(btrans == matrix_op_t::Transpose) ? k : n,
(btrans == matrix_op_t::Transpose) ? k_per_group : n,
nullptr,
groups);

Expand Down Expand Up @@ -537,8 +537,8 @@ TEST_P(fbgemmu8s8acc32test, TestSymmetricQuantizedInputOutput) {
n,
Bint8.data() + g * k_per_group * n,
n,
Bint8_temp.data() + g * k_per_group,
groups * k_per_group);
Bint8_temp.data() + g * k_per_group * n,
k_per_group);
}
Bint8 = Bint8_temp;
}
Expand All @@ -562,7 +562,7 @@ TEST_P(fbgemmu8s8acc32test, TestSymmetricQuantizedInputOutput) {
k,
n_adjusted,
Bint8.data(),
(btrans == matrix_op_t::Transpose) ? k : n,
(btrans == matrix_op_t::Transpose) ? k_per_group : n,
nullptr,
groups);

Expand Down

0 comments on commit db52c82

Please sign in to comment.