skip acc16 test when vnni is available (pytorch#183)
Summary:
Pull Request resolved: pytorch#183

Tests were failing when run under Intel SDE (Software Development Emulator) in -clx mode, which emulates a Cascade Lake CPU with AVX-512 VNNI.

Reviewed By: jianyuh

Differential Revision: D18471607

fbshipit-source-id: b80aaaad111fd33cc068fbfc3f2d8abb37253c6d
jspark1105 authored and facebook-github-bot committed Nov 14, 2019
1 parent 4e8cee1 commit ac90a86
Showing 2 changed files with 59 additions and 30 deletions.
12 changes: 10 additions & 2 deletions src/ExecuteKernelU8S8.cc
@@ -181,8 +181,16 @@ void ExecuteKernel<
   int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
   if (nc != nbSize_) {
     if (fbgemmHasAvx512VnniSupport()) {
-      fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
-          accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+      if (std::is_same<typename packingAMatrix::accType, std::int16_t>::
+              value) {
+        // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+        CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+        fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+            accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+      } else {
+        fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+            accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+      }
     } else if (fbgemmHasAvx512Support()) {
       fn = BaseType::template getOrCreate<inst_set_t::avx512>(
           accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
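The hunk above redirects 16-bit accumulation to the 32-bit code path whenever AVX-512 VNNI is available: the vpdpbusd instruction accumulates u8 * s8 dot products directly into 32-bit lanes, so a separate 16-bit accumulation kernel offers no benefit there. A minimal, self-contained sketch of that dispatch idea (generateKernel and its signature are illustrative, not FBGEMM's API):

#include <cstdint>
#include <iostream>
#include <type_traits>

// Sketch: pick the accumulation width for the generated kernel. On VNNI
// hardware a 16-bit request falls through to the 32-bit path.
template <typename AccT>
void generateKernel(bool hasVnni) {
  if (hasVnni && std::is_same<AccT, std::int16_t>::value) {
    // Redirect, mirroring the diff above: 16-bit accumulation has no
    // advantage when vpdpbusd accumulates into 32-bit lanes anyway.
    generateKernel<std::int32_t>(hasVnni);
    return;
  }
  std::cout << "kernel with " << sizeof(AccT) * 8 << "-bit accumulation\n";
}

int main() {
  generateKernel<std::int16_t>(/*hasVnni=*/true);  // -> 32-bit accumulation
  generateKernel<std::int16_t>(/*hasVnni=*/false); // -> 16-bit accumulation
}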
77 changes: 49 additions & 28 deletions test/PackedRequantizeAcc16Test.cc
@@ -4,6 +4,7 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+#include <cpuinfo.h>
 #include <algorithm>
 #include <chrono>
 #include <cmath>
@@ -69,33 +70,35 @@ INSTANTIATE_TEST_CASE_P(
  * @brief Shapes for unit test.
  */
 static vector<vector<int>> GetShapes_() {
+  // clang-format off
   // NMT
   vector<vector<int>> shapes = {
-    // {M, N, K}
-    {1, 128, 512},
-    {1, 1024, 256},
-    {1, 2048, 512},
-    {1, 2048, 513},
-    {1, 2048, 514},
-
-    {6, 512, 512},
-    {6, 2048, 512},
-    {6, 256, 1024},
-    {6, 1024, 256},
-    {6, 2048, 256},
-    {6, 2048, 257},
-    {6, 2048, 258},
-
-    {102, 1024, 512},
-    {102, 2323, 256},
-    {102, 512, 256},
-    {102, 512, 257},
-    {102, 512, 258},
-
-    {1024, 512, 258},
-
-    {120, 4, 288},
+      // {M, N, K}
+      {1, 128, 512},
+      {1, 1024, 256},
+      {1, 2048, 512},
+      {1, 2048, 513},
+      {1, 2048, 514},
+
+      {6, 512, 512},
+      {6, 2048, 512},
+      {6, 256, 1024},
+      {6, 1024, 256},
+      {6, 2048, 256},
+      {6, 2048, 257},
+      {6, 2048, 258},
+
+      {102, 1024, 512},
+      {102, 2323, 256},
+      {102, 512, 256},
+      {102, 512, 257},
+      {102, 512, 258},
+
+      {1024, 512, 258},
+
+      {120, 4, 288},
   };
+  // clang-format on
   return shapes;
 }
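Each {M, N, K} triple above is one GEMM problem size. A sketch (illustrative, not the tests' exact body) of how the shape list is consumed:

// Illustrative consumption of GetShapes_(); the loop body is a placeholder.
for (const auto& shape : GetShapes_()) {
  const int m = shape[0]; // rows of A and C
  const int n = shape[1]; // columns of B and C
  const int k = shape[2]; // inner (reduction) dimension
  // ... run an m x n x k u8 * s8 GEMM with 16-bit accumulation and
  // compare against a reference implementation ...
}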

@@ -104,6 +107,12 @@ static vector<vector<int>> GetShapes_() {
  * accumulation. Output processing: requantization -> nothing
  */
 TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, Test) {
+  cpuinfo_initialize();
+  if (fbgemmHasAvx512VnniSupport()) {
+    // No need to use acc16 if VNNI is available
+    return;
+  }
+
   vector<vector<int>> shapes(GetShapes_());
   matrix_op_t atrans, btrans;
   bool test_ld;
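The added guard is the core of this commit: acc16 exists to trade overflow handling for faster 16-bit accumulation on pre-VNNI hardware, and since the ExecuteKernelU8S8.cc change never selects acc16 kernels when VNNI is present, the tests bail out before doing any work. A standalone sketch of the same pattern, using the real cpuinfo and FBGEMM helpers with googletest (the test name here is hypothetical):

#include <cpuinfo.h>
#include <gtest/gtest.h>
#include "fbgemm/Utils.h" // declares fbgemmHasAvx512VnniSupport()

TEST(Acc16SkipSketch, SkipsOnVnni) {
  // cpuinfo must be initialized before any CPU-feature query.
  cpuinfo_initialize();
  if (fbgemm::fbgemmHasAvx512VnniSupport()) {
    // Nothing to test: acc16 is bypassed entirely on VNNI hardware.
    // On googletest >= 1.10, GTEST_SKIP() would report this as skipped.
    return;
  }
  // ... actual acc16 GEMM checks would go here ...
}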
@@ -345,6 +354,12 @@ TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, Test) {
  * accumulation. Output processing: spmdm -> requantization -> nothing
  */
 TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, SpMDMTest) {
+  cpuinfo_initialize();
+  if (fbgemmHasAvx512VnniSupport()) {
+    // No need to use acc16 if VNNI is available
+    return;
+  }
+
   vector<vector<int>> shapes(GetShapes_());
   matrix_op_t atrans, btrans;
   bool test_ld;
@@ -671,6 +686,12 @@ TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, SpMDMTest) {
  * accumulation. Output processing: nothing
  */
 TEST_P(fbgemmu8s8acc16Test, NoRequantizeTest) {
+  cpuinfo_initialize();
+  if (fbgemmHasAvx512VnniSupport()) {
+    // No need to use acc16 if VNNI is available
+    return;
+  }
+
   vector<vector<int>> shapes(GetShapes_());
   matrix_op_t atrans, btrans;
   bool test_ld;
@@ -884,10 +905,10 @@ TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) {
     for (int i = 0; i < k; i++) {
       for (int j = 0; j < n_adjusted; j++) {
         EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
-          << "Pack/Unpack results differ at index (" << i << ", " << j
-          << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
-          << ", Pack-Unpacked: "
-          << static_cast<int>(unpack_buf.data()[i * n + j]);
+            << "Pack/Unpack results differ at index (" << i << ", " << j
+            << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
+            << ", Pack-Unpacked: "
+            << static_cast<int>(unpack_buf.data()[i * n + j]);
       }
     }
   }
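The EXPECT_EQ above verifies the pack/unpack round trip: unpacking a packed B matrix must reproduce every original byte. A toy, self-contained illustration of that invariant, using a made-up row-pair-interleaved layout rather than FBGEMM's real packing:

#include <cassert>
#include <cstdint>
#include <vector>

using Mat = std::vector<std::int8_t>;

// Toy "packing": walk a k x n row-major B in row pairs, column by column.
Mat pack(const Mat& b, int k, int n) {
  Mat out(b.size());
  int idx = 0;
  for (int i = 0; i < k; i += 2)
    for (int j = 0; j < n; ++j)
      for (int r = i; r < i + 2 && r < k; ++r)
        out[idx++] = b[r * n + j];
  return out;
}

// Exact inverse of pack(): reads sequentially, scatters back to row-major.
Mat unpack(const Mat& p, int k, int n) {
  Mat out(p.size());
  int idx = 0;
  for (int i = 0; i < k; i += 2)
    for (int j = 0; j < n; ++j)
      for (int r = i; r < i + 2 && r < k; ++r)
        out[r * n + j] = p[idx++];
  return out;
}

int main() {
  const int k = 5, n = 3; // odd k exercises the trailing half-pair
  Mat b(k * n);
  for (int i = 0; i < k * n; ++i) b[i] = static_cast<std::int8_t>(i);
  assert(unpack(pack(b, k, n), k, n) == b); // round trip must be lossless
}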
