Skip to content

Commit

Permalink
use int64_t in fbgemmPartition1D (pytorch#933)
Browse files: browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#933

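Use int64_t for the total_work, start, and end arguments of fbgemmPartition1D (and for the len arguments of the quantization routines that call it) in order to handle large tensors whose element counts do not fit in a 32-bit int.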

Reviewed By: jianyuh

Differential Revision: D34234352

fbshipit-source-id: bdfc9eb43efb63ac8582ef0c56a3d6fe7cd01da3
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Feb 15, 2022
1 parent 5ae23f9 commit f7050c9
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 70 deletions.
4 changes: 2 additions & 2 deletions include/fbgemm/FbgemmFPCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ void cblas_gemm_compute(
gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);

if ((n % Bp.blockColSize()) == 0) {
int jb_begin, jb_end;
int64_t jb_begin, jb_end;
fbgemmPartition1D(
thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end);
gp.B += gp.k * Bp.blockColSize() * jb_begin;
Expand All @@ -173,7 +173,7 @@ void cblas_gemm_compute(
} else {
int last_blk_col = nbcol * Bp.blockColSize();
if (nbcol) {
int jb_begin, jb_end;
int64_t jb_begin, jb_end;
fbgemmPartition1D(
thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end);
gp.B += gp.k * Bp.blockColSize() * jb_begin;
Expand Down
14 changes: 7 additions & 7 deletions include/fbgemm/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ template <typename T, bool LEGACY = true>
FBGEMM_API void Quantize(
const float* src,
T* dst,
int len,
std::int64_t len,
const TensorQuantizationParams& qparams,
int thread_id = 0,
int num_threads = 1);
Expand Down Expand Up @@ -144,13 +144,13 @@ template <typename T>
void Dequantize(
const T* src,
float* dst,
int len,
std::int64_t len,
const TensorQuantizationParams& qparams,
int thread_id = 0,
int num_threads = 1) {
int i_begin, i_end;
int64_t i_begin, i_end;
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
for (auto i = i_begin; i < i_end; i++) {
for (int64_t i = i_begin; i < i_end; i++) {
dst[i] = Dequantize(src[i], qparams);
}
}
Expand All @@ -173,7 +173,7 @@ template <typename T>
FBGEMM_API void FusedQuantizeDequantize(
const float* src,
float* dst,
int len,
std::int64_t len,
const TensorQuantizationParams& qparams,
int thread_id = 0,
int num_threads = 1,
Expand Down Expand Up @@ -215,7 +215,7 @@ template <typename T>
FBGEMM_API void RequantizeFixedPoint(
const std::int32_t* src,
T* dst,
int len,
std::int64_t len,
const RequantizationParams& params,
int thread_id = 0,
int num_threads = 1);
Expand Down Expand Up @@ -249,7 +249,7 @@ template <typename T>
FBGEMM_API void Requantize(
const std::int32_t* src,
T* dst,
int len,
std::int64_t len,
const RequantizationParams& params,
int thread_id = 0,
int num_threads = 1);
Expand Down
12 changes: 6 additions & 6 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ bool isValidBlockingFactor(BlockingFactors* param) {
FBGEMM_API void fbgemmPartition1D(
int thread_id,
int num_threads,
int total_work,
int& start,
int& end);
std::int64_t total_work,
std::int64_t& start,
std::int64_t& end);

/**
* @brief Partition work across given number of threads in blocks
Expand Down Expand Up @@ -392,8 +392,8 @@ FBGEMM_API void fbgemmPartition1D(
FBGEMM_API void fbgemmPartition1DBlocked(
int thread_id,
int num_threads,
int total_work,
std::int64_t total_work,
int block_size,
int& start,
int& end);
std::int64_t& start,
std::int64_t& end);
} // namespace fbgemm
4 changes: 2 additions & 2 deletions src/ExecuteKernelU8S8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void ExecuteKernel<
bool lastKBlock = packedB_.isThisLastKBlock(kBlock % packedB_.blockRows());
bool accum = (kBlock % packedB_.blockRows()) > 0;

int jb_begin, jb_end;
int64_t jb_begin, jb_end;
fbgemmPartition1D(
th_info_.n_thread_id,
th_info_.n_num_threads,
Expand Down Expand Up @@ -329,7 +329,7 @@ void ExecuteKernel<
C_buffer_row_start + jb_begin * nbSize_,
{row_start_A,
packed_rows_A,
NDim * group + jb_begin * nbSize_,
static_cast<int>(NDim * group + jb_begin * nbSize_),
nSize},
ldc_,
ldc_);
Expand Down
4 changes: 2 additions & 2 deletions src/Fbgemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void fbgemmPacked(
throw std::runtime_error("unknown architecure");
}

int MCB;
int64_t MCB;
int KCB;
int MR;

Expand Down Expand Up @@ -144,7 +144,7 @@ void fbgemmPacked(
// if (thread_id == 0)
// std::cout << ", " << th_info.toString();

int g_begin, g_end, i_begin, i_end;
int64_t g_begin, g_end, i_begin, i_end;

// Calculate the begin and end index along the group dimension
fbgemmPartition1D(
Expand Down
6 changes: 3 additions & 3 deletions src/FbgemmI8Depthwise2DAvx2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,16 @@ static ALWAYS_INLINE void depthwise_2d_(
int num_threads) {
assert(IC % 8 == 0);
constexpr int R = S;
constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
PAD_R = PAD_L;
constexpr int64_t PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
PAD_R = PAD_L;
int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const std::int8_t* Bp = B.PackedMat();

int32_t* row_offsets = static_cast<int32_t*>(
fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

int n_begin, n_end, h_begin, h_end, w_begin, w_end;
int64_t n_begin, n_end, h_begin, h_end, w_begin, w_end;
// Reuse the 3-dim partition scheme for parallelization in matrix
// multiplication.
thread_type_t th_info =
Expand Down
6 changes: 3 additions & 3 deletions src/FbgemmI8Depthwise3DAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ static ALWAYS_INLINE void depthwise_3d_same_pad_(
int K_T = F[0], K_H = F[1], K_W = F[2];
int PAD_P = (F[0] - 1) / 2, PAD_N = PAD_P, PAD_T = (F[1] - 1) / 2,
PAD_B = PAD_T, PAD_L = (F[2] - 1) / 2, PAD_R = PAD_L;
int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
int64_t T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
int64_t H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
const int8_t* Bp = B.PackedMat();

int32_t* row_offsets = static_cast<int32_t*>(
fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

int n_begin, n_end, t_begin, t_end, h_begin, h_end;
int64_t n_begin, n_end, t_begin, t_end, h_begin, h_end;
// Reuse the 3-dim partition scheme for parallelization in matrix
// multiplication.
thread_type_t th_info =
Expand Down
24 changes: 12 additions & 12 deletions src/GroupwiseConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1022,10 +1022,10 @@ void fbgemmGroupwiseConv(
}
if (SPATIAL_DIM == 2) {
// Parallelization:
int batch_start = 0;
int batch_end = MB;
int oh_start = 0;
int oh_end = OH;
int64_t batch_start = 0;
int64_t batch_end = MB;
int64_t oh_start = 0;
int64_t oh_end = OH;
if (MB >= num_threads) {
fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
} else {
Expand Down Expand Up @@ -1100,8 +1100,8 @@ void fbgemmGroupwiseConv(

const int32_t* inp = out_start_group;
block_type_t block{
i * OT_OH_OW + oh_start * OW,
(oh_end - oh_start) * OW,
static_cast<int>(i * OT_OH_OW + oh_start * OW),
static_cast<int>((oh_end - oh_start) * OW),
g * K_per_G,
G_together * K_per_G};
int ld_out = G * K_per_G;
Expand Down Expand Up @@ -1139,10 +1139,10 @@ void fbgemmGroupwiseConv(
conv_param.pad[5]});

// Parallelization:
int batch_start = 0;
int batch_end = MB;
int oh_start = 0;
int oh_end = OH;
int64_t batch_start = 0;
int64_t batch_end = MB;
int64_t oh_start = 0;
int64_t oh_end = OH;
if (MB >= num_threads) {
fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
} else {
Expand Down Expand Up @@ -1243,8 +1243,8 @@ void fbgemmGroupwiseConv(

const int32_t* inp = out_start_t;
block_type_t block{
i * OT_OH_OW + oh_start * OW,
(oh_end - oh_start) * OW,
static_cast<int>(i * OT_OH_OW + oh_start * OW),
static_cast<int>((oh_end - oh_start) * OW),
g * K_per_G,
G_together * K_per_G};
int ld_out = G * K_per_G;
Expand Down
48 changes: 24 additions & 24 deletions src/QuantUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ void ChooseRequantizationMultiplier(
FBGEMM_API void Quantize<T, LEGACY>( \
const float* src, \
T* dst, \
const int len, \
const int64_t len, \
const TensorQuantizationParams& qparams, \
int thread_id, \
int num_threads) { \
int i_begin, i_end; \
int64_t i_begin, i_end; \
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
for (int i = i_begin; i < i_end; ++i) { \
for (int64_t i = i_begin; i < i_end; ++i) { \
dst[i] = Quantize<T, LEGACY>(src[i], qparams); \
} \
}
Expand All @@ -217,20 +217,20 @@ FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false)
FBGEMM_API void Quantize<T, LEGACY>( \
const float* src, \
T* dst, \
int len, \
int64_t len, \
const TensorQuantizationParams& qparams, \
int thread_id, \
int num_threads) { \
bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
bool fma_support = cpuinfo_has_x86_fma3(); \
int i_begin, i_end; \
int64_t i_begin, i_end; \
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
if (avx2_support && fma_support && qparams.precision == 8) { \
/* fast path */ \
QuantizeAvx2<T, LEGACY>( \
&src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \
} else { \
for (int i = i_begin; i < i_end; ++i) { \
for (int64_t i = i_begin; i < i_end; ++i) { \
dst[i] = Quantize<T, LEGACY>(src[i], qparams); \
} \
} \
Expand All @@ -247,21 +247,21 @@ FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false)
FBGEMM_API void FusedQuantizeDequantize<T>( \
const float* src, \
float* dst, \
int len, \
int64_t len, \
const TensorQuantizationParams& qparams, \
int thread_id, \
int num_threads, \
float noise_ratio) { \
bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
bool fma_support = cpuinfo_has_x86_fma3(); \
int i_begin, i_end; \
int64_t i_begin, i_end; \
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
if (avx2_support && fma_support && qparams.precision == 8) { \
/* fast path */ \
FusedQuantizeDequantizeAvx2<T>( \
&src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \
} else if (noise_ratio <= 0.0f) { \
for (int i = i_begin; i < i_end; ++i) { \
for (int64_t i = i_begin; i < i_end; ++i) { \
dst[i] = FusedQuantizeDequantize<T>(src[i], qparams); \
} \
} else { \
Expand Down Expand Up @@ -399,13 +399,13 @@ int64_t SaturatingRoundingMulWithShift(int32_t a, int32_t b, int right_shift) {
FBGEMM_API void Requantize<T>( \
const int32_t* src, \
T* dst, \
const int len, \
const int64_t len, \
const RequantizationParams& params, \
int thread_id, \
int num_threads) { \
int i_begin, i_end; \
int64_t i_begin, i_end; \
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
for (int i = i_begin; i < i_end; ++i) { \
for (int64_t i = i_begin; i < i_end; ++i) { \
dst[i] = Requantize<T>(src[i], params); \
} \
}
Expand All @@ -417,17 +417,17 @@ template <>
FBGEMM_API void Requantize<uint8_t>(
const int32_t* src,
uint8_t* dst,
const int len,
const int64_t len,
const RequantizationParams& params,
int thread_id,
int num_threads) {
int i_begin, i_end;
int64_t i_begin, i_end;
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
if (params.target_qparams.precision == 8 && cpuinfo_initialize() &&
fbgemmHasAvx2Support()) {
RequantizeAvx2(&src[i_begin], &dst[i_begin], i_end - i_begin, params);
} else {
for (int i = i_begin; i < i_end; ++i) {
for (int64_t i = i_begin; i < i_end; ++i) {
dst[i] = Requantize<uint8_t>(src[i], params);
}
}
Expand All @@ -437,18 +437,18 @@ template <typename T>
FBGEMM_API void RequantizeFixedPoint(
const std::int32_t* src,
T* dst,
int len,
int64_t len,
const RequantizationParams& params,
int thread_id,
int num_threads) {
int i_begin, i_end;
int64_t i_begin, i_end;
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
if (std::is_same<T, uint8_t>::value && params.target_qparams.precision == 8 &&
cpuinfo_initialize() && fbgemmHasAvx2Support()) {
RequantizeFixedPointAvx2(
&src[i_begin], &dst[i_begin], i_end - i_begin, params);
} else {
for (int i = i_begin; i < i_end; ++i) {
for (int64_t i = i_begin; i < i_end; ++i) {
dst[i] = RequantizeFixedPoint<T>(src[i], params);
}
}
Expand All @@ -459,13 +459,13 @@ FBGEMM_API void RequantizeFixedPoint(
FBGEMM_API void RequantizeFixedPoint<T>( \
const int32_t* src, \
T* dst, \
const int len, \
const int64_t len, \
const RequantizationParams& params, \
int thread_id, \
int num_threads) { \
int i_begin, i_end; \
int64_t i_begin, i_end; \
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
for (int i = i_begin; i < i_end; ++i) { \
for (int64_t i = i_begin; i < i_end; ++i) { \
dst[i] = RequantizeFixedPoint<T>(src[i], params); \
} \
}
Expand All @@ -477,19 +477,19 @@ template <>
FBGEMM_API void RequantizeFixedPoint<uint8_t>(
const int32_t* src,
uint8_t* dst,
const int len,
const int64_t len,
const RequantizationParams& params,
int thread_id,
int num_threads) {
int i_begin, i_end;
int64_t i_begin, i_end;
fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);

if (params.target_qparams.precision == 8 && cpuinfo_initialize() &&
fbgemmHasAvx2Support()) {
RequantizeFixedPointAvx2(
&src[i_begin], &dst[i_begin], i_end - i_begin, params);
} else {
for (int i = i_begin; i < i_end; ++i) {
for (int64_t i = i_begin; i < i_end; ++i) {
dst[i] = RequantizeFixedPoint<uint8_t>(src[i], params);
}
}
Expand Down
Loading

0 comments on commit f7050c9

Please sign in to comment.