diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 6dc7a2a767..d2c2a1ee70 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -373,6 +373,16 @@ class PackBMatrix final : public PackMatrix, T, accT> { PackBMatrix() = delete; // no default constructor + /** + * @param groups if > 1 and trans == NoTranspose, smat is nRow x nCol with + * groups vertically concatenated: each group is + * (nRow / groups) x nCol. + * If > 1 and trans == Transpose, smat is (nCol * groups) x + * (nRow / groups) with groups horizontally concatenated: + * each group is nCol x (nRow / groups). Each group is + * transposed and vertically concatenated to match the + * NoTranspose case. + */ PackBMatrix( matrix_op_t trans, std::int32_t nRow, diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 2cb99db035..1b3899771d 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -69,7 +69,7 @@ void PackBMatrix::pack(const block_type_t& block) { g * this->packedBufferSize(block.row_size, block.col_size); for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { - T val = tr ? smat_[g * block.row_size + i + ld_ * j] + T val = tr ? smat_[i + (g * block.col_size + j) * ld_] : smat_[(g * block.row_size + i) * ld_ + j]; out[addr(i, j)] = tconv(val, out[addr(i, j)]); } diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc index cc77f8bde8..cb614cdd65 100644 --- a/test/PackedRequantizeAcc16Test.cc +++ b/test/PackedRequantizeAcc16Test.cc @@ -118,8 +118,8 @@ TEST_P(fbgemmu8s8acc16test, Test) { n, Bint8.data() + g * k_per_group * n, n, - Bint8_temp.data() + g * k_per_group, - groups * k_per_group); + Bint8_temp.data() + g * k_per_group * n, + k_per_group); } Bint8 = Bint8_temp; } @@ -197,7 +197,7 @@ TEST_P(fbgemmu8s8acc16test, Test) { k, n_adjusted, Bint8.data(), - (btrans == matrix_op_t::Transpose) ? 
k : n, + (btrans == matrix_op_t::Transpose) ? k_per_group : n, nullptr, groups); @@ -376,8 +376,8 @@ TEST_P(fbgemmu8s8acc16test, SpMDMTest) { n, Bint8.data() + g * k_per_group * n, n, - Bint8_temp.data() + g * k_per_group, - groups * k_per_group); + Bint8_temp.data() + g * k_per_group * n, + k_per_group); } Bint8 = Bint8_temp; } @@ -442,7 +442,7 @@ TEST_P(fbgemmu8s8acc16test, SpMDMTest) { k, n_adjusted, Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k : n, + (btrans == matrix_op_t::Transpose) ? k_per_group : n, nullptr, groups); @@ -566,8 +566,8 @@ TEST_P(fbgemmu8s8acc16test, NoRequantizeTest) { n, Bint8.data() + g * k_per_group * n, n, - Bint8_temp.data() + g * k_per_group, - groups * k_per_group); + Bint8_temp.data() + g * k_per_group * n, + k_per_group); } Bint8 = Bint8_temp; } @@ -628,7 +628,7 @@ TEST_P(fbgemmu8s8acc16test, NoRequantizeTest) { k, n_adjusted, Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k : n, + (btrans == matrix_op_t::Transpose) ? k_per_group : n, nullptr, groups); diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc index 267a76403e..a5744c09f5 100644 --- a/test/PackedRequantizeTest.cc +++ b/test/PackedRequantizeTest.cc @@ -140,8 +140,8 @@ TEST_P(fbgemmu8s8acc32test, Test) { n, Bint8.data() + g * k_per_group * n, n, - Bint8_temp.data() + g * k_per_group, - groups * k_per_group); + Bint8_temp.data() + g * k_per_group * n, + k_per_group); } Bint8 = Bint8_temp; } @@ -217,7 +217,7 @@ TEST_P(fbgemmu8s8acc32test, Test) { k, n_adjusted, Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k : n, + (btrans == matrix_op_t::Transpose) ? 
k_per_group : n, nullptr, groups); @@ -372,8 +372,8 @@ TEST_P(fbgemmu8s8acc32test, TestFloatInputOutput) { n, Bint8.data() + g * k_per_group * n, n, - Bint8_temp.data() + g * k_per_group, - groups * k_per_group); + Bint8_temp.data() + g * k_per_group * n, + k_per_group); } Bint8 = Bint8_temp; } @@ -396,7 +396,7 @@ TEST_P(fbgemmu8s8acc32test, TestFloatInputOutput) { k, n_adjusted, Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k : n, + (btrans == matrix_op_t::Transpose) ? k_per_group : n, nullptr, groups); @@ -537,8 +537,8 @@ TEST_P(fbgemmu8s8acc32test, TestSymmetricQuantizedInputOutput) { n, Bint8.data() + g * k_per_group * n, n, - Bint8_temp.data() + g * k_per_group, - groups * k_per_group); + Bint8_temp.data() + g * k_per_group * n, + k_per_group); } Bint8 = Bint8_temp; } @@ -562,7 +562,7 @@ TEST_P(fbgemmu8s8acc32test, TestSymmetricQuantizedInputOutput) { k, n_adjusted, Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k : n, + (btrans == matrix_op_t::Transpose) ? k_per_group : n, nullptr, groups);