Skip to content

Commit

Permalink
move avx2 masks to MaskAvx2.h and reuse (pytorch#141)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#141

Refactor to dedup mask arrays for avx2

Reviewed By: jianyuh

Differential Revision: D17916885

fbshipit-source-id: f3d48f16411a45307b44484fe39b472a08dd6c62
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Oct 15, 2019
1 parent 266b453 commit f40e2d0
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 101 deletions.
47 changes: 24 additions & 23 deletions src/FbgemmI8Depthwise2DAvx2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tuple> // for tie

#include "src/FbgemmI8DepthwiseAvx2-inl.h"
#include "src/MaskAvx2.h"

namespace fbgemm {

Expand Down Expand Up @@ -44,8 +45,8 @@ static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
_mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(masks[remainder / 4]));
mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_masks[remainder / 4]));
}

// The code below can be written as a simple R*S loop but the compiler
Expand Down Expand Up @@ -156,8 +157,8 @@ static inline __attribute__((always_inline)) void inner_prod_5x5_packed_(
_mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(masks[remainder / 4]));
mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_masks[remainder / 4]));
}

// The code below can be written as a simple R*S loop but the compiler
Expand Down Expand Up @@ -1600,24 +1601,24 @@ FBGEMM_API void depthwise_3x3_pad_1(

template <typename BIAS_TYPE = std::int32_t>
FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
int W,
int K,
int stride_h,
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
const BIAS_TYPE* bias,
bool fuse_relu = false,
const float* act_times_w_scale = nullptr,
int thread_id = 0,
int num_threads = 1);
int N,
int H,
int W,
int K,
int stride_h,
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
const std::int32_t* col_offsets,
const BIAS_TYPE* bias,
bool fuse_relu = false,
const float* act_times_w_scale = nullptr,
int thread_id = 0,
int num_threads = 1);

} // namespace fbgemm
5 changes: 3 additions & 2 deletions src/FbgemmI8Depthwise3DAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tuple> // for tie

#include "FbgemmI8DepthwiseAvx2-inl.h"
#include "MaskAvx2.h"

using namespace std;

Expand Down Expand Up @@ -37,8 +38,8 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
__m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(masks[remainder / 4]));
mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_masks[remainder / 4]));
}

// The code below can be written as a simple R*S loop but the compiler
Expand Down
15 changes: 0 additions & 15 deletions src/FbgemmI8DepthwiseAvx2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,6 @@

namespace fbgemm {

// clang-format off
static int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{ 0, 0, 0, 0, 0, 0, 0, 0, },
{ -1, 0, 0, 0, 0, 0, 0, 0, },
{ -1, -1, 0, 0, 0, 0, 0, 0, },
{ -1, -1, -1, 0, 0, 0, 0, 0, },
{ -1, -1, -1, -1, 0, 0, 0, 0, },
{ -1, -1, -1, -1, -1, 0, 0, 0, },
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
// clang-format on

// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
// A is in uint8_t
// B is in int8_t and pre-interleaved
Expand Down
32 changes: 32 additions & 0 deletions src/MaskAvx2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

namespace fbgemm {

namespace internal {

// A constant array to initialize an AVX2 register to be used as a 32-bit
// granularity mask.
//
// Row r has its first r elements set to -1 (all bits set) and the rest to 0,
// so avx2_ps_or_epi32_masks[r] selects the low r 32-bit lanes of a 256-bit
// vector. Callers load a row with _mm256_load_si256 and feed it to masked
// loads such as _mm256_maskload_epi32 to handle a trailing remainder of
// fewer than 8 int32/float elements.
//
// alignas(64): each row is 32 bytes, so rows stay 32-byte aligned as required
// by the aligned _mm256_load_si256 at the call sites (64 presumably also
// keeps the table cache-line aligned).
//
// `static` gives each including TU its own copy; the table is 256 bytes.
// clang-format off
alignas(64) static const int avx2_ps_or_epi32_masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{ 0, 0, 0, 0, 0, 0, 0, 0, },
{ -1, 0, 0, 0, 0, 0, 0, 0, },
{ -1, -1, 0, 0, 0, 0, 0, 0, },
{ -1, -1, -1, 0, 0, 0, 0, 0, },
{ -1, -1, -1, -1, 0, 0, 0, 0, },
{ -1, -1, -1, -1, -1, 0, 0, 0, },
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
// clang-format on

} // namespace internal

} // namespace fbgemm
21 changes: 4 additions & 17 deletions src/PackDepthwiseConvMatrixAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,12 @@

#include <immintrin.h>

#include "MaskAvx2.h"

using namespace std;

namespace fbgemm {

// clang-format off
static int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{ 0, 0, 0, 0, 0, 0, 0, 0, },
{ -1, 0, 0, 0, 0, 0, 0, 0, },
{ -1, -1, 0, 0, 0, 0, 0, 0, },
{ -1, -1, -1, 0, 0, 0, 0, 0, },
{ -1, -1, -1, -1, 0, 0, 0, 0, },
{ -1, -1, -1, -1, -1, 0, 0, 0, },
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
// clang-format on

PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
int K,
int kernel_prod,
Expand Down Expand Up @@ -105,8 +92,8 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
__m256i b_v[kernel_prod];
int remainder = K - k1;
if (remainder < 32) {
__m256i mask_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(masks[remainder / 4]));
__m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_masks[remainder / 4]));
for (int i = 0; i < kernel_prod; ++i) {
b_v[i] = _mm256_maskload_epi32(
reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
Expand Down
56 changes: 12 additions & 44 deletions src/QuantUtilsAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <algorithm> //for std::min/std::max
#include <cmath> //for nearbyint
#include <limits> //for numeric_limits
#include "MaskAvx2.h"
#include "fbgemm/Fbgemm.h" //for ReQuantizeOutput

namespace fbgemm {
Expand Down Expand Up @@ -99,20 +100,17 @@ void QuantizeAvx2(
}

// Instantiate QuantizeAvx2 for known datatypes
template
void QuantizeAvx2<uint8_t>(
template void QuantizeAvx2<uint8_t>(
const float* src,
uint8_t* dst,
int len,
const TensorQuantizationParams& qparams);
template
void QuantizeAvx2<int8_t>(
template void QuantizeAvx2<int8_t>(
const float* src,
int8_t* dst,
int len,
const TensorQuantizationParams& qparams);


void FindMinMax(const float* a, float* min, float* max, int len) {
if (len <= 0) {
*min = 0.0f;
Expand Down Expand Up @@ -160,7 +158,7 @@ void RequantizeAvx2(
int len,
const RequantizationParams& params) {
DoNothing<> doNothingObj{};
int32_t Bq_zero_point[] = { 0 };
int32_t Bq_zero_point[] = {0};
ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj(
doNothingObj,
&params.real_multiplier,
Expand Down Expand Up @@ -670,29 +668,14 @@ void requantizeOutputProcessingAvx2(

int remainder = block.col_start + block.col_size - j;
if (remainder > 0) {
// clang-format off
alignas(64) const int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
{ 0, 0, 0, 0, 0, 0, 0, 0, },
{ -1, 0, 0, 0, 0, 0, 0, 0, },
{ -1, -1, 0, 0, 0, 0, 0, 0, },
{ -1, -1, -1, 0, 0, 0, 0, 0, },
{ -1, -1, -1, -1, 0, 0, 0, 0, },
{ -1, -1, -1, -1, -1, 0, 0, 0, },
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
// clang-format on
__m256i mask_v = _mm256_load_si256(
reinterpret_cast<const __m256i*>(masks[remainder]));
__m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_masks[remainder]));

__m256i x_v = _mm256_maskload_epi32(
inp + (i - block.row_start) * ld_in + (j - block.col_start),
mask_v);
inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);

if (!A_SYMMETRIC) {
__m256i col_off_v = _mm256_mullo_epi32(
__m256i col_off_v = _mm256_mullo_epi32(
A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
x_v = _mm256_sub_epi32(x_v, col_off_v);
}
Expand Down Expand Up @@ -880,29 +863,14 @@ void requantizeForFloatAvx2(

int remainder = block.col_start + block.col_size - j;
if (remainder > 0) {
// clang-format off
alignas(64) const int masks[8][8] = {
// NOTE: clang-format wants to use a different formatting but the
// current formatting should be easier to read.
{ 0, 0, 0, 0, 0, 0, 0, 0, },
{ -1, 0, 0, 0, 0, 0, 0, 0, },
{ -1, -1, 0, 0, 0, 0, 0, 0, },
{ -1, -1, -1, 0, 0, 0, 0, 0, },
{ -1, -1, -1, -1, 0, 0, 0, 0, },
{ -1, -1, -1, -1, -1, 0, 0, 0, },
{ -1, -1, -1, -1, -1, -1, 0, 0, },
{ -1, -1, -1, -1, -1, -1, -1, 0, },
};
// clang-format on
__m256i mask_v = _mm256_load_si256(
reinterpret_cast<const __m256i*>(masks[remainder]));
__m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
internal::avx2_ps_or_epi32_masks[remainder]));

__m256i x_v = _mm256_maskload_epi32(
inp + (i - block.row_start) * ld_in + (j - block.col_start),
mask_v);
inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);

if (!A_SYMMETRIC) {
__m256i col_off_v = _mm256_mullo_epi32(
__m256i col_off_v = _mm256_mullo_epi32(
A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
x_v = _mm256_sub_epi32(x_v, col_off_v);
}
Expand Down

0 comments on commit f40e2d0

Please sign in to comment.