forward: support batch and channel dimensions
proger committed Jan 9, 2024
1 parent 5708b67 commit 818e7b5
Showing 2 changed files with 48 additions and 45 deletions.
85 changes: 44 additions & 41 deletions accelerated_scan/warp.cuh
@@ -5,13 +5,19 @@
 #include <torch/extension.h>
 
 template <int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock>
-__global__ void scan(float* gates, float* tokens, float* result) {
+__global__ void scan(
+    float* gates,
+    float* tokens,
+    float* result,
+    int batch_stride,
+    int dim_stride
+) {
     __shared__ float warpLastGate[kNWarpsPerBlock];
     __shared__ float warpLastToken[kNWarpsPerBlock];
 
-    const int threadId = threadIdx.x;
-    const int warpId = threadId / kNThreadsPerWarp;
-    const int laneId = threadId % kNThreadsPerWarp;
+    const int offset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
+    const int warpId = threadIdx.x / kNThreadsPerWarp;
+    const int laneId = threadIdx.x % kNThreadsPerWarp;
     constexpr int kWarpLast = kNThreadsPerWarp - 1;
     constexpr int kThreadLast = kNStepsPerThread - 1;
     constexpr float kEmptyGate = 1.0;
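
With the grid set up as dim3 grid(batch_size, dim) further down, blockIdx.x selects the batch element and blockIdx.y the channel, so offset is the flat index of the first element of that block's sequence. A minimal sketch of the same index arithmetic in Python (names are illustrative, not from the repo):

import torch

# Element (b, d, t) of a contiguous (batch_size, dim, seqlen) tensor lives at
# flat index b * batch_stride + d * dim_stride + t; inside the kernel, t is
# enumerated as threadIdx.x * kNStepsPerThread + i.
x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
batch_stride, dim_stride = x.stride(0), x.stride(1)
flat = x.flatten()
for b in range(2):
    for d in range(3):
        offset = b * batch_stride + d * dim_stride  # per-block base offset
        for t in range(4):
            assert flat[offset + t] == x[b, d, t]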
@@ -26,10 +32,10 @@ __global__ void scan(float* gates, float* tokens, float* result) {
 
     #pragma unroll
     for (int i = 0; i < kNStepsPerThread; ++i) {
-        float gate = gates[threadId * kNStepsPerThread + i];
-        float token = tokens[threadId * kNStepsPerThread + i];
+        float gate = gates[offset + threadIdx.x * kNStepsPerThread + i];
+        float token = tokens[offset + threadIdx.x * kNStepsPerThread + i];
         if (i == 0) {
-            acc[i] = {threadId == 0 ? kEmptyGate : gate, token};
+            acc[i] = {threadIdx.x == 0 ? kEmptyGate : gate, token};
         } else {
             acc[i] = {acc[i - 1].x * gate, acc[i - 1].y * gate + token};
         }
@@ -100,9 +106,9 @@ __global__ void scan(float* gates, float* tokens, float* result) {
     #pragma unroll
     for (int i = 0; i < kNStepsPerThread; ++i) {
         if (warpId > 0) {
-            result[threadId * kNStepsPerThread + i] = warpLastToken[warpId-1] * acc[i].x + acc[i].y;
+            result[offset + threadIdx.x * kNStepsPerThread + i] = warpLastToken[warpId-1] * acc[i].x + acc[i].y;
         } else {
-            result[threadId * kNStepsPerThread + i] = acc[i].y;
+            result[offset + threadIdx.x * kNStepsPerThread + i] = acc[i].y;
         }
     }
 }
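
For reference, the recurrence each block computes over its (batch, channel) row is h[t] = gate[t] * h[t-1] + token[t] with h[-1] = 0 (kEmptyGate feeds the first element an identity gate). A minimal sequential sketch of that semantics, assuming (batch_size, dim, seqlen) inputs; this is an illustration, not the repo's scan_ref:

import torch

def scan_sequential(gates: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # gates, tokens: (batch_size, dim, seqlen); each row is scanned independently
    result = torch.empty_like(tokens)
    h = torch.zeros_like(tokens[..., 0])
    for t in range(tokens.shape[-1]):
        h = gates[..., t] * h + tokens[..., t]  # h[t] = g[t] * h[t-1] + x[t]
        result[..., t] = h
    return result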
@@ -113,6 +119,12 @@ warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Te
     TORCH_CHECK(gates.scalar_type() == at::ScalarType::Float);
     TORCH_CHECK(tokens.is_cuda());
     TORCH_CHECK(gates.is_cuda());
+    TORCH_CHECK(tokens.is_contiguous());
+    TORCH_CHECK(gates.is_contiguous());
+
+    const auto strides = tokens.strides();
+    const int batch_stride = strides[0];
+    const int dim_stride = strides[1];
     TORCH_CHECK(tokens.stride(-1) == 1 || tokens.size(-1) == 1);
     TORCH_CHECK(gates.stride(-1) == 1 || gates.size(-1) == 1);
 
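The strides are read from tokens, and the new is_contiguous guards restrict inputs to plain row-major storage. Note they are stricter than the existing last-dim checks: a transposed view keeps stride(-1) == 1 but is no longer accepted. A small illustration, assuming the checks behave as written above:

import torch

x = torch.rand(4, 8, 32, device="cuda")  # (batch, dim, seqlen)
y = x.transpose(0, 1)                    # (dim, batch, seqlen) view
assert y.stride(-1) == 1                 # passes the old last-dim check
assert not y.is_contiguous()             # rejected by the new guard
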
@@ -122,81 +134,72 @@ warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Te
     const int seqlen = sizes[2];
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-
-    dim3 grid(1,1);
+    dim3 grid(batch_size, dim);
     constexpr int kNThreadsPerWarp = 32;
 
     if (seqlen == 32) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 1;
-        int kNThreads = seqlen;
+        int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else if (seqlen == 64) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 1;
-        int kNThreads = seqlen;
+        int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else if (seqlen == 128) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 4;
-        int kNThreads = seqlen;
+        int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
        );
     } else if (seqlen == 256) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 8;
-        int kNThreads = seqlen;
+        int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else if (seqlen == 512) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 16;
-        int kNThreads = seqlen;
+        int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else if (seqlen == 1024) {
         constexpr int kNStepsPerThread = 1;
         constexpr int kNWarpsPerBlock = 32;
-        int kNThreads = seqlen;
+        int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else if (seqlen == 2048) {
         constexpr int kNStepsPerThread = 2;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else if (seqlen == 4096) {
         constexpr int kNStepsPerThread = 4;
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         scan<kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock><<<grid, kNThreads, kNWarpsPerBlock * sizeof(float2), stream>>>(
-            gates.data_ptr<float>(),
-            tokens.data_ptr<float>(),
-            out.data_ptr<float>()
+            gates.data_ptr<float>(), tokens.data_ptr<float>(), out.data_ptr<float>(),
+            batch_stride, dim_stride
         );
     } else {
         TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 4096");
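
Each block launches kNThreadsPerWarp * kNWarpsPerBlock threads and each thread owns kNStepsPerThread consecutive elements, so every branch must satisfy seqlen == kNStepsPerThread * 32 * kNWarpsPerBlock. A quick sanity check of the dispatch table above (a sketch, not part of the repo):

configs = {
    # seqlen: (kNStepsPerThread, kNWarpsPerBlock)
    32: (1, 1), 64: (2, 1), 128: (1, 4), 256: (1, 8),
    512: (1, 16), 1024: (1, 32), 2048: (2, 32), 4096: (4, 32),
}
for seqlen, (steps, warps) in configs.items():
    assert seqlen == steps * 32 * warps  # threads per block = seqlen / steps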
8 changes: 4 additions & 4 deletions tests/test_eq.py
@@ -10,10 +10,10 @@
 seqlens = [2**i for i in range(5, 13)]
 
 
-def init(seed, seqlen=32, requires_grad=False):
+def init(seed, batch_size=3, dim=1536, seqlen=32, requires_grad=False):
     torch.manual_seed(seed)
-    gates = 0.999 + 0.001 * torch.rand(1, 1, seqlen, requires_grad=requires_grad, device="cuda")
-    tokens = torch.rand(1, 1, seqlen, requires_grad=requires_grad, device="cuda")
+    gates = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, requires_grad=requires_grad, device="cuda")
+    tokens = torch.rand(batch_size, dim, seqlen, requires_grad=requires_grad, device="cuda")
     if requires_grad:
         gates.retain_grad()
         tokens.retain_grad()
@@ -24,7 +24,7 @@ def init(seed, seqlen=32, requires_grad=False):
 @pytest.mark.parametrize("seqlen", seqlens)
 @torch.inference_mode()
 def test_eq_forward(seed, seqlen):
-    gates, tokens = init(seed, seqlen)
+    gates, tokens = init(seed, seqlen=seqlen)
     out = scan(gates, tokens)
     out_ref = scan_ref(gates, tokens)
 
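With the new defaults, the fixtures exercise batched, multi-channel inputs on every run. A hedged sketch of driving the fixture with a smaller shape (argument names from init above; the comparison is illustrative, the suite's actual tolerances may differ):

gates, tokens = init(seed=0, batch_size=2, dim=8, seqlen=128)
out = scan(gates, tokens)          # CUDA kernel under test
out_ref = scan_ref(gates, tokens)  # sequential reference
torch.testing.assert_close(out, out_ref)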
