diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index 4ae1b50fef..91736b0b58 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -181,8 +181,16 @@ void ExecuteKernel<
       int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
       if (nc != nbSize_) {
         if (fbgemmHasAvx512VnniSupport()) {
-          fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
-              accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+          if (std::is_same<typename packingAMatrix::accType, std::int16_t>::
+                  value) {
+            // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+            CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+            fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+                accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+          } else {
+            fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+                accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+          }
         } else if (fbgemmHasAvx512Support()) {
           fn = BaseType::template getOrCreate<inst_set_t::avx512>(
               accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc
index 93e75661ac..a7507717df 100644
--- a/test/PackedRequantizeAcc16Test.cc
+++ b/test/PackedRequantizeAcc16Test.cc
@@ -4,6 +4,7 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+#include <cpuinfo.h>
 #include <algorithm>
 #include <cmath>
 #include <cstdint>
@@ -69,33 +70,35 @@ INSTANTIATE_TEST_CASE_P(
 /**
  * @brief Shapes for unit test.
  */
 static vector<vector<int>> GetShapes_() {
+  // clang-format off
   // NMT
   vector<vector<int>> shapes = {
-      // {M, N, K}
-      {1, 128, 512},
-      {1, 1024, 256},
-      {1, 2048, 512},
-      {1, 2048, 513},
-      {1, 2048, 514},
-
-      {6, 512, 512},
-      {6, 2048, 512},
-      {6, 256, 1024},
-      {6, 1024, 256},
-      {6, 2048, 256},
-      {6, 2048, 257},
-      {6, 2048, 258},
-
-      {102, 1024, 512},
-      {102, 2323, 256},
-      {102, 512, 256},
-      {102, 512, 257},
-      {102, 512, 258},
-
-      {1024, 512, 258},
-
-      {120, 4, 288},
+    // {M, N, K}
+    {1, 128, 512},
+    {1, 1024, 256},
+    {1, 2048, 512},
+    {1, 2048, 513},
+    {1, 2048, 514},
+
+    {6, 512, 512},
+    {6, 2048, 512},
+    {6, 256, 1024},
+    {6, 1024, 256},
+    {6, 2048, 256},
+    {6, 2048, 257},
+    {6, 2048, 258},
+
+    {102, 1024, 512},
+    {102, 2323, 256},
+    {102, 512, 256},
+    {102, 512, 257},
+    {102, 512, 258},
+
+    {1024, 512, 258},
+
+    {120, 4, 288},
   };
+  // clang-format on
   return shapes;
 }
@@ -104,6 +107,12 @@ static vector<vector<int>> GetShapes_() {
  * accumulation. Output processing: requantization -> nothing
  */
 TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, Test) {
+  cpuinfo_initialize();
+  if (fbgemmHasAvx512VnniSupport()) {
+    // No need to use acc16 if VNNI is available
+    return;
+  }
+
   vector<vector<int>> shapes(GetShapes_());
   matrix_op_t atrans, btrans;
   bool test_ld;
@@ -345,6 +354,12 @@ TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, Test) {
  * accumulation. Output processing: spmdm -> requantization -> nothing
  */
 TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, SpMDMTest) {
+  cpuinfo_initialize();
+  if (fbgemmHasAvx512VnniSupport()) {
+    // No need to use acc16 if VNNI is available
+    return;
+  }
+
   vector<vector<int>> shapes(GetShapes_());
   matrix_op_t atrans, btrans;
   bool test_ld;
@@ -671,6 +686,12 @@ TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, SpMDMTest) {
  * accumulation. Output processing: nothing
  */
 TEST_P(fbgemmu8s8acc16Test, NoRequantizeTest) {
+  cpuinfo_initialize();
+  if (fbgemmHasAvx512VnniSupport()) {
+    // No need to use acc16 if VNNI is available
+    return;
+  }
+
   vector<vector<int>> shapes(GetShapes_());
   matrix_op_t atrans, btrans;
   bool test_ld;
@@ -884,10 +905,10 @@ TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) {
   for (int i = 0; i < k; i++) {
     for (int j = 0; j < n_adjusted; j++) {
       EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
-        << "Pack/Unpack results differ at index (" << i << ", " << j
-        << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
-        << ", Pack-Unpacked: "
-        << static_cast<int>(unpack_buf.data()[i * n + j]);
+          << "Pack/Unpack results differ at index (" << i << ", " << j
+          << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
+          << ", Pack-Unpacked: "
+          << static_cast<int>(unpack_buf.data()[i * n + j]);
     }
   }
 }
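
Note on the ExecuteKernelU8S8.cc hunk (not part of the patch): the change is a compile-time dispatch. When the kernel template was instantiated for int16_t accumulation but the host supports AVX512-VNNI, code generation is redirected to a 32-bit accumulator, since the VNNI instruction accumulates u8*s8 products directly into 32-bit lanes and the 16-bit accumulation path gains nothing there; the acc16 tests are likewise skipped on VNNI machines. Below is a minimal self-contained sketch of the same std::is_same-based dispatch pattern; KernelGen and hasVnni() are illustrative stand-ins, not FBGEMM's actual API.

#include <cstdint>
#include <iostream>
#include <type_traits>

// Illustrative stand-in for a code generator whose template parameter
// plays the role of the accumulation type (like FBGEMM's CodeGenBase).
template <typename accT>
struct KernelGen {
  void run() const {
    std::cout << "kernel with " << 8 * sizeof(accT) << "-bit accumulation\n";
  }
};

// Illustrative runtime CPU check; FBGEMM uses fbgemmHasAvx512VnniSupport().
bool hasVnni() {
  return true;
}

// Mirrors the patch: if the caller asked for int16_t accumulation but VNNI
// is available, fall back to an int32_t code generator instead.
template <typename accT>
void generate() {
  if (hasVnni() && std::is_same<accT, std::int16_t>::value) {
    KernelGen<std::int32_t>{}.run(); // redirect int16_t -> int32_t
  } else {
    KernelGen<accT>{}.run();
  }
}

int main() {
  generate<std::int16_t>(); // redirected: prints 32-bit accumulation
  generate<std::int32_t>(); // unchanged: already 32-bit
}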