BlasKernel: Improve gemm's inner dot product when a is transposed (pytorch#80977)

`gemm_transab_` accumulates its running sum directly in the output buffer, even though the inner loop only ever touches a single output element. This changes it to accumulate in a register instead, which also avoids truncating the partial sum back to bfloat16 on every addition.
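
To illustrate the truncation issue, here is a minimal standalone sketch (mine, not from this PR) contrasting accumulation in `c10::BFloat16` with accumulation in a `float` register; with bfloat16's short mantissa, the running sum stalls once the addend drops below half a ulp of the total:

```cpp
#include <c10/util/BFloat16.h>
#include <cstdio>

int main() {
  c10::BFloat16 acc_out{0.0f}; // like accumulating in the bf16 output
  float acc_reg = 0.0f;        // like accumulating in a register
  for (int i = 0; i < 600; ++i) {
    acc_out = acc_out + c10::BFloat16(0.1f);
    acc_reg += 0.1f;
  }
  // bf16 spacing at 32 is 0.25, so 32 + 0.1 rounds back to 32 and the
  // sum stalls there; the float accumulator reaches ~60 as expected.
  std::printf("bf16: %f  float: %f\n", static_cast<float>(acc_out), acc_reg);
}
```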

I've also factored out a generic `sum` function that can be shared
with `gemm_transa_` to handle unrolling and multiple accumulators.
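
As an illustration of the intended usage (a hypothetical caller, not code from this diff), a dot product reduces to a one-liner, with `sum` supplying the unrolling and the independent partial accumulators:

```cpp
// Hypothetical caller of the new helper: sum() unrolls by 4 and keeps 4
// independent accumulators, breaking the serial addition dependency chain.
float dot(const float* a, const float* b, int64_t n) {
  return sum(n, [&](int64_t i) -> float { return a[i] * b[i]; });
}
```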

I have benchmarked addmm for bfloat16 with shapes
(320,600) X (600,320), and both layouts show a significant
speedup (a rough reproduction sketch follows the table).

|  layout  | Before (ms) | After (ms) |
|----------|-------------|------------|
| transa   | 71.5        | 31         |
| transab  | 249         | 35         |
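
A libtorch harness along these lines should reproduce the measurement (warm-up and iteration counts are illustrative; this is not the exact script behind the numbers above):

```cpp
#include <torch/torch.h>
#include <chrono>
#include <cstdio>

int main() {
  auto opts = torch::TensorOptions().dtype(torch::kBFloat16);
  // transa layout: a is a (320,600) transposed view. For transab, also
  // make b a transposed view of a (320,600) tensor.
  auto a = torch::randn({600, 320}, opts).t();
  auto b = torch::randn({600, 320}, opts);
  auto c = torch::randn({320, 320}, opts);
  for (int i = 0; i < 10; ++i) {
    torch::addmm(c, a, b); // warm-up
  }
  constexpr int iters = 100;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    torch::addmm(c, a, b);
  }
  auto end = std::chrono::steady_clock::now();
  std::printf("addmm: %.2f ms/iter\n",
      std::chrono::duration<double, std::milli>(end - start).count() / iters);
}
```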
Pull Request resolved: pytorch#80977
Approved by: https://github.com/ngimel
peterbell10 authored and pytorchmergebot committed Oct 9, 2022

1 parent a45fead commit 753536b
Showing 1 changed file with 42 additions and 25 deletions.
aten/src/ATen/native/cpu/BlasKernel.cpp
@@ -2,6 +2,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/native/CPUBlas.h>
 #include <c10/util/irange.h>
+#include <c10/util/Unroll.h>
 
 namespace at {
 namespace native {
@@ -30,6 +31,29 @@ void scale_(int64_t m, int64_t n, opmath_t alpha, scalar_t *a, int64_t lda) {
   }
 }
 
+template <typename Func>
+auto sum(int64_t N, Func f) {
+  constexpr int ilp_factor = 4;
+  using acc_t = decltype(f(0));
+
+  // Calculate independent partial sums then add together at the end
+  std::array<acc_t, ilp_factor> partial_sums{};
+
+  int64_t i = 0;
+  for (; i + ilp_factor <= N; i += ilp_factor) {
+    c10::ForcedUnroll<ilp_factor>{}([&](int k) {
+      partial_sums[k] += f(i + k);
+    });
+  }
+  for (; i < N; ++i) {
+    partial_sums[0] += f(i);
+  }
+  for (int k = 1; k < ilp_factor; ++k) {
+    partial_sums[0] += partial_sums[k];
+  }
+  return partial_sums[0];
+}
+
 
 template <typename scalar_t, typename opmath_t>
 void gemm_notrans_(
@@ -73,15 +97,15 @@ void gemm_transa_(
   for (const auto i : c10::irange(m)) {
     const scalar_t *b_ = b;
     for (const auto j : c10::irange(n)) {
-      opmath_t sum = 0;
-      for (const auto l : c10::irange(k)) {
-        sum += static_cast<opmath_t>(a_[l]) * static_cast<opmath_t>(b_[l]);
-      }
+      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
+        return static_cast<opmath_t>(a_[l]) * static_cast<opmath_t>(b_[l]);
+      });
       b_ += ldb;
-      if (beta == scalar_t(0))
-        c[j*ldc+i] = alpha*sum;
-      else
-        c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
+      if (beta == opmath_t(0)) {
+        c[j*ldc+i] = alpha*dot;
+      } else {
+        c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
+      }
     }
     a_ += lda;
   }
@@ -124,26 +148,19 @@ void gemm_transab_(
     const scalar_t *b, int64_t ldb,
     opmath_t beta,
     scalar_t *c, int64_t ldc) {
-  // c *= beta
-  scale_(m, n, beta, c, ldc);
-
-  // c += alpha * (a.T @ b.T)
+  // c = beta * c + alpha * (a.T @ b.T)
   for (const auto i : c10::irange(m)) {
     for (const auto j : c10::irange(n)) {
-      int64_t l_k = k / 4;
-      for (const auto l_l : c10::irange(l_k)) {
-        c[j * ldc + i] += a[i * lda + l_l * 4 + 0] //
-            * (b[(l_l * 4 + 0) * ldb + j] * alpha);
-        c[j * ldc + i] += a[i * lda + l_l * 4 + 1] //
-            * (b[(l_l * 4 + 1) * ldb + j] * alpha);
-        c[j * ldc + i] += a[i * lda + l_l * 4 + 2] //
-            * (b[(l_l * 4 + 2) * ldb + j] * alpha);
-        c[j * ldc + i] += a[i * lda + l_l * 4 + 3] //
-            * (b[(l_l * 4 + 3) * ldb + j] * alpha);
-      }
-      int64_t l = l_k * 4;
-      for (; l < k; l++)
-        c[j * ldc + i] += a[i * lda + l] * (b[l * ldb + j] * alpha);
+      const auto dot = sum(k, [&](int64_t l) -> opmath_t {
+        return static_cast<opmath_t>(a[i * lda + l]) *
+            static_cast<opmath_t>(b[l * ldb + j]);
+      });
+
+      if (beta == opmath_t(0)) {
+        c[j * ldc + i] = alpha * dot;
+      } else {
+        c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot;
+      }
     }
   }
 }
