Wextra pedantic fbgemm's (pytorch#642)
Summary:
Pull Request resolved: pytorch#642

Enable stricter compilation flags (-Wextra, -pedantic) to enforce a variety of safety measures, and fix the sign-compare and unused-variable warnings they surface.

Reviewed By: jianyuh

Differential Revision: D29408318

fbshipit-source-id: 04ad39ae45fb9664aa1e05262f87b8f11e716e31
r-barnes authored and facebook-github-bot committed Jul 1, 2021
1 parent 5dc753c commit 31905d2
Showing 16 changed files with 70 additions and 58 deletions.
13 changes: 7 additions & 6 deletions include/fbgemm/FbgemmFPCommon.h
@@ -111,7 +111,7 @@ void cblas_gemm_compute(
i_end = m;
for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
int mb = std::min(mb_max, i_end - m0);
assert(mb < partition.size());
assert(mb < static_cast<int64_t>(partition.size()));
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
// set up proper accumulation to avoid "Nan" problem
float beta_;
@@ -128,13 +128,13 @@

auto m1 = m0;
auto const num_cycles = partition[mb].size();
for (auto c = 0; c < num_cycles; ++c) {
for (size_t c = 0; c < num_cycles; ++c) {
auto kernel_nrows = partition[mb][c][0];
auto nkernel_nrows = partition[mb][c][1];
auto m_start = m1;
auto m_end = m1 + kernel_nrows * nkernel_nrows;
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
assert(kernel_nrows * kb < scratchpad->size());
assert(kernel_nrows * kb < static_cast<int64_t>(scratchpad->size()));
if (m != 1) {
PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
gp.A = scratchpad->data();
@@ -190,14 +190,15 @@ void cblas_gemm_compute(
// use one thread to handle the fringe cases
if (thread_id == num_threads - 1) {
// leftover
int rem = n - last_blk_col;
const int rem = n - last_blk_col;
(void)rem; // Suppress unused variable warning
assert(rem < Bp.blockColSize());

// small temporary buffer: the size should be larger than the
// required kernel_nrow x kernel_ncols elements computed in the
// registers.
std::array<float, 14 * 32> c_tmp{0.f};
assert(c_tmp.size() >= kernel_nrows * Bp.blockColSize());
assert(static_cast<int64_t>(c_tmp.size()) >= kernel_nrows * Bp.blockColSize());

gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp.data();
@@ -213,7 +214,7 @@
for (int j = last_blk_col; j < n; j++) {
assert(
i * Bp.blockColSize() + (j - last_blk_col) <
sizeof(c_tmp) / sizeof(c_tmp[0]));
static_cast<int64_t>(sizeof(c_tmp) / sizeof(c_tmp[0])));
if (beta_ == 0.f) {
C[(m2 + i) * ldc + j] =
c_tmp[i * Bp.blockColSize() + (j - last_blk_col)];
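The changes in this file follow a pattern that recurs throughout the commit: under -Wextra, comparing a signed index against the unsigned result of `size()` triggers -Wsign-compare, so the unsigned operand is cast to `int64_t` inside the assert. A minimal standalone sketch of the idiom (illustrative only, not FBGEMM code):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the asserts above.
void check_block(int mb, const std::vector<std::array<int, 2>>& partition) {
  // assert(mb < partition.size());  // warns under -Wextra: 'int' vs 'size_t'
  assert(mb < static_cast<int64_t>(partition.size()));  // both sides signed
}
```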
4 changes: 2 additions & 2 deletions include/fbgemm/FbgemmPackMatrixB.h
@@ -202,10 +202,10 @@ class PackedGemmMatrixB {
}

const T& operator()(const int r, const int c) const {
uint64_t a = addr(r, c);
const auto a = addr(r, c);
assert(r < numRows());
assert(c < numCols());
assert(a < this->matSize());
assert(static_cast<int64_t>(a) < this->matSize());
return pmat_[a];
}

2 changes: 1 addition & 1 deletion include/fbgemm/Utils.h
@@ -129,7 +129,7 @@ FBGEMM_API int compare_buffers(
int m,
int n,
int ld,
int max_mismatches_to_report,
size_t max_mismatches_to_report,
float atol = 1e-3);

/**
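Here the fix goes through the signature rather than a cast: making `max_mismatches_to_report` a `size_t` presumably matches it to the unsigned counters it is compared against, so no mixed-signedness comparison remains. A hypothetical illustration of the same choice (not the actual compare_buffers implementation):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical reporter: prints up to max_mismatches_to_report differences.
void report_mismatches(const float* ref, const float* test, std::size_t n,
                       std::size_t max_mismatches_to_report) {
  std::size_t reported = 0;
  for (std::size_t i = 0; i < n && reported < max_mismatches_to_report; ++i) {
    // With an int parameter, `reported < max_mismatches_to_report` would mix
    // signed and unsigned operands and trip -Wsign-compare.
    if (ref[i] != test[i]) {
      std::printf("mismatch at %zu: %f vs %f\n", i, ref[i], test[i]);
      ++reported;
    }
  }
}
```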
1 change: 1 addition & 0 deletions src/FbgemmI64.cc
@@ -155,6 +155,7 @@ CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
#endif

const int maxMRegs = mRegBlockSize;
(void)maxMRegs; // Suppress unused variable warning
const int maxNRegs = nRegBlockSize / vectorLen;
assert(
maxMRegs * maxNRegs <= 30 &&
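`maxMRegs` is referenced only inside the assert, and `assert` compiles to nothing when NDEBUG is defined, so -Wextra would report an unused variable in release builds; the `(void)` cast counts as a use and silences the warning without changing behavior. A standalone sketch of the idiom (not FBGEMM code):

```cpp
#include <cassert>

int pick_n_regs(int mRegBlockSize, int nRegBlockSize, int vectorLen) {
  const int maxMRegs = mRegBlockSize;
  (void)maxMRegs; // Used only by the assert below; in -DNDEBUG builds the
                  // assert vanishes and this cast prevents an
                  // unused-variable warning.
  const int maxNRegs = nRegBlockSize / vectorLen;
  assert(maxMRegs * maxNRegs <= 30 && "register budget exceeded");
  return maxNRegs;
}
```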
1 change: 1 addition & 0 deletions src/FbgemmSparseDenseInt8Avx2.cc
@@ -66,6 +66,7 @@ void SparseDenseInt8MMAvx2(
constexpr int VLEN_INT8 = 32;
constexpr int VLEN_INT32 = 8;
constexpr int rowBlockSize = BCSRMatrix<>::RB;
(void)rowBlockSize; // Suppress unused variable warning
constexpr int colBlockSize = BCSRMatrix<>::CB;

constexpr int colTileSize = BCSRMatrix<>::COLTILE;
6 changes: 4 additions & 2 deletions src/GenerateKernelU8S8S32ACC16.cc
@@ -160,8 +160,10 @@ getOrCreate<inst_set_t::avx2>(
assert(
kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
const int maxMRegs = mRegBlockSize;
const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
(void)maxMRegs; // Suppress unused variable warning
(void)maxNRegs; // Suppress unused variable warning
assert(
maxMRegs * maxNRegs <= 13 &&
"MR*(NR*ROW_INTERLEAVE*8/256"
5 changes: 3 additions & 2 deletions src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -126,8 +126,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate(
assert(
kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
const int maxMRegs = mRegBlockSize;
(void)maxMRegs; // Suppress unused variable warning
const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
assert(
(maxMRegs + 1) * maxNRegs <= 29 &&
"number of zmm registers for C + one row for loading B: \
5 changes: 3 additions & 2 deletions src/GenerateKernelU8S8S32ACC32.cc
@@ -161,8 +161,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
assert(
kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
const int maxMRegs = mRegBlockSize;
(void)maxMRegs; // Suppress unused variable warning
const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
assert(
maxMRegs * maxNRegs <= numRegs - 4 &&
"MRegs x NRegs is above available registers (MAX_REGS - 4)");
5 changes: 3 additions & 2 deletions src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -108,8 +108,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
assert(
kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
const int maxMRegs = mRegBlockSize;
const int maxNRegs = nRegBlockSize * row_interleave / vectorLen;
(void)maxMRegs; // Suppress unused variable warning
assert(
maxMRegs * maxNRegs <= 30 &&
"MR*(NR*ROW_INTERLEAVE*8/512) \
24 changes: 12 additions & 12 deletions src/QuantUtils.cc
@@ -230,7 +230,7 @@ FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false)
QuantizeAvx2<T, LEGACY>( \
&src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \
} else { \
for (std::size_t i = i_begin; i < i_end; ++i) { \
for (int i = i_begin; i < i_end; ++i) { \
dst[i] = Quantize<T, LEGACY>(src[i], qparams); \
} \
} \
@@ -261,7 +261,7 @@ FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false)
FusedQuantizeDequantizeAvx2<T>( \
&src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \
} else if (noise_ratio <= 0.0f) { \
for (std::size_t i = i_begin; i < i_end; ++i) { \
for (int i = i_begin; i < i_end; ++i) { \
dst[i] = FusedQuantizeDequantize<T>(src[i], qparams); \
} \
} else { \
@@ -510,7 +510,7 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
(input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
2 * sizeof(float16);
std::vector<float> input_row_float(input_columns);
for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const InputType* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
float16* output_row_scale_bias = reinterpret_cast<float16*>(
@@ -519,7 +519,7 @@

// NOTE: this can be optimized, however we don't care much about performance
// for reference implementation.
for (std::size_t col = 0; col < input_columns; ++col) {
for (int col = 0; col < input_columns; ++col) {
if (std::is_same<InputType, float>()) {
input_row_float[col] = input_row[col];
} else {
@@ -553,7 +553,7 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(

output_row_scale_bias[0] = cpu_float2half_rn(scale);
output_row_scale_bias[1] = minimum_element_fp16;
for (std::size_t col = 0; col < input_columns; ++col) {
for (int col = 0; col < input_columns; ++col) {
float X = input_row_float[col];
std::uint8_t quantized = std::max(
0,
@@ -619,13 +619,13 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(

int output_columns = input_columns + 2 * sizeof(float);
std::vector<float> input_row_float(input_columns);
for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const InputType* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
float* output_row_scale_bias =
reinterpret_cast<float*>(output_row + input_columns);

for (std::size_t col = 0; col < input_columns; ++col) {
for (int col = 0; col < input_columns; ++col) {
if (std::is_same<InputType, float>()) {
input_row_float[col] = input_row[col];
} else {
@@ -642,7 +642,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
output_row_scale_bias[0] = range / 255.0f;
output_row_scale_bias[1] = minimum_element;
const auto inverse_scale = 255.0f / (range + kEpsilon);
for (std::size_t col = 0; col < input_columns; ++col) {
for (int col = 0; col < input_columns; ++col) {
output_row[col] =
std::lrintf((input_row_float[col] - minimum_element) * inverse_scale);
}
@@ -678,7 +678,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
int output_columns =
(input_columns - 2 * sizeof(float16)) * num_elem_per_byte;

for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const float16* input_row_scale_bias = reinterpret_cast<const float16*>(
input_row +
@@ -687,7 +687,7 @@
float bias = cpu_half2float(input_row_scale_bias[1]);
OutputType* output_row = output + row * output_columns;

for (std::size_t col = 0; col < output_columns; ++col) {
for (int col = 0; col < output_columns; ++col) {
std::uint8_t quantized = input_row[col / num_elem_per_byte];
quantized >>= (col % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1;
@@ -740,13 +740,13 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
OutputType* output) {
int output_columns = input_columns - 2 * sizeof(float);

for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const float* input_row_scale_bias =
reinterpret_cast<const float*>(input_row + output_columns);
OutputType* output_row = output + row * output_columns;

for (std::size_t col = 0; col < output_columns; ++col) {
for (int col = 0; col < output_columns; ++col) {
float output_value =
input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
if (std::is_same<OutputType, float>()) {
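The loop-counter changes in this file go the other way from the casts above: the bounds (`input_rows`, `input_columns`, `i_begin`/`i_end`) are signed integers, so counting with `std::size_t` compared an unsigned index against a signed bound on every iteration; switching the counters to `int` puts both sides of the comparison at the same signedness. A minimal sketch (illustrative only, not FBGEMM code):

```cpp
// Hypothetical row-wise loop with int-typed bounds, as in the functions above.
void scale_rows(float* data, int input_rows, int input_columns, float s) {
  // for (std::size_t row = 0; row < input_rows; ++row)  // -Wsign-compare
  for (int row = 0; row < input_rows; ++row) {  // counter matches bound type
    for (int col = 0; col < input_columns; ++col) {
      data[row * input_columns + col] *= s;
    }
  }
}
```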
23 changes: 12 additions & 11 deletions src/QuantUtilsAvx2.cc
@@ -36,7 +36,7 @@ void QuantizeAvx2(
// that is exactly representable in float
constexpr int32_t int32_float_max_val =
std::numeric_limits<int32_t>::max() - 127;
std::size_t i = 0;
int i = 0;
float inverse_scale = 1.f / qparams.scale;
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
// clang-format off
@@ -170,7 +170,7 @@ void NO_SANITIZE("address") FusedQuantizeDequantizeAvx2(
// that is exactly representable in float
constexpr int32_t int32_float_max_val =
std::numeric_limits<int32_t>::max() - 127;
std::size_t i = 0;
int i = 0;
uint32_t rand;
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
__m256 scale_v = _mm256_set1_ps(qparams.scale);
@@ -1356,7 +1356,8 @@ void requantizeOutputProcessingGConvAvx2(
_mm256_castsi256_si128(x_clamped_v));
} // j loop vectorized

int remainder = block.col_start + block.col_size - j;
const int remainder = block.col_start + block.col_size - j;
(void)remainder; // Suppress unused variable warning
assert(remainder == 0);
} // i loop
}
@@ -1505,7 +1506,7 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
}

for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const InputType* input_row = input + row * input_columns;
const float* input_row_float;
if (std::is_same<InputType, float>()) {
@@ -1527,7 +1528,7 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
__m256 min_v = _mm256_set1_ps(minimum_element);
__m256 max_v = _mm256_set1_ps(maximum_element);

std::size_t col;
int col;
for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
__m256 in_v;
if (std::is_same<InputType, float>()) {
@@ -1707,7 +1708,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
input_row_float_for_fp16 = static_cast<float*>(
fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
}
for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const InputType* input_row = input + row * input_columns;
const float* input_row_float;
if (std::is_same<InputType, float>()) {
@@ -1726,7 +1727,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
float maximum_element = -FLT_MAX;
__m256 min_v = _mm256_set1_ps(minimum_element);
__m256 max_v = _mm256_set1_ps(maximum_element);
std::size_t col;
int col;
for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
__m256 in_v;
if (std::is_same<InputType, float>()) {
@@ -1888,7 +1889,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
(VLEN + 1))));
}

for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const uint16_t* input_row_scale_bias = reinterpret_cast<const uint16_t*>(
input_row +
@@ -1904,7 +1905,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
output_row_float = reinterpret_cast<float*>(output_row);
}

std::size_t col = 0;
int col = 0;
if (BIT_RATE == 4 || BIT_RATE == 2) {
__m256 vscale = _mm256_set1_ps(scale);
__m256 vbias = _mm256_set1_ps(bias);
@@ -2060,7 +2061,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
constexpr int VLEN = 8;
int output_columns = input_columns - 2 * sizeof(float);

for (std::size_t row = 0; row < input_rows; ++row) {
for (int row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const float* input_row_scale_bias =
reinterpret_cast<const float*>(input_row + output_columns);
@@ -2069,7 +2070,7 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
__m256 scale_v = _mm256_set1_ps(input_row_scale_bias[0]);
__m256 bias_v = _mm256_set1_ps(input_row_scale_bias[1]);

std::size_t col;
int col;
for (col = 0; col < output_columns / VLEN * VLEN; col += VLEN) {
__m256 in_v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(input_row + col))));
6 changes: 3 additions & 3 deletions src/RowWiseSparseAdagradFused.cc
@@ -634,7 +634,7 @@ typename ReturnFunctionSignature<indxType, offsetType, dataType>::
x86::rsp,
x86::ptr(
x86::rsp, static_cast<int>(-vlen * sizeof(float16))));
for (size_t r = 0; r < remainder; ++r) {
for (int r = 0; r < remainder; ++r) {
a->mov(
h.r16(),
x86::word_ptr(
@@ -652,7 +652,7 @@
// Truncate rounding to 'counterwork' the random added part
a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 11);
// Copy results back
for (size_t r = 0; r < remainder; ++r) {
for (int r = 0; r < remainder; ++r) {
a->mov(h.r16(), x86::ptr(x86::rsp, sizeof(dataType) * r));
a->mov(
x86::word_ptr(
@@ -788,7 +788,7 @@ void rand_initialize() {
for (auto i = 0; i < 4; ++i) {
g_rnd128v_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
uint64_t h1 = g_rnd128v_buffer[i * VLEN_MAX];
for (auto v = 1; v < VLEN_MAX; ++v) {
for (size_t v = 1; v < VLEN_MAX; ++v) {
g_rnd128v_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
}
}
4 changes: 2 additions & 2 deletions src/SparseAdagrad.cc
@@ -803,7 +803,7 @@ int SparseAdaGradBlockSize1_(
if (weight_decay != 0.0f) {
for (int i = 0; i < num_rows; ++i) {
IndexType idx = indices[i];
if (idx >= param_size) {
if (idx >= static_cast<int64_t>(param_size)) {
return i;
}

@@ -821,7 +821,7 @@
} else {
for (int i = 0; i < num_rows; ++i) {
IndexType idx = indices[i];
if (idx >= param_size) {
if (idx >= static_cast<int64_t>(param_size)) {
return i;
}
float gi = g[i];