Extend softmax fusion seq length to 16384 (NVIDIA#1558)
Signed-off-by: yaoyu-33 <[email protected]>
yaoyu-33 authored Jan 19, 2023
1 parent 91fcaaf · commit e1e0531
Showing 6 changed files with 40 additions and 8 deletions.

apex/transformer/functional/fused_softmax.py (2 additions, 2 deletions)
@@ -229,12 +229,12 @@ def is_kernel_available(self, mask, b, np, sq, sk):
                 self.attn_mask_type == AttnMaskType.causal
                 or self.attn_mask_type == AttnMaskType.padding
             )
-            and 16 < sk <= 4096  # sk must be 16 ~ 4096
+            and 16 < sk <= 16384  # sk must be 16 ~ 16384
             and sq % 4 == 0  # sq must be divisible by 4
             and sk % 4 == 0  # sk must be divisible by 4
             and attn_batches % 4 == 0  # np * b must be divisible by 4
         ):
-            if 0 <= sk <= 4096:
+            if 0 <= sk <= 16384:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                 if self.attn_mask_type == AttnMaskType.causal:
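
For context, the gate above decides whether the fused CUDA kernel can serve a given shape. A minimal standalone sketch of that predicate (plain Python, not the apex API; names mirror the diff, where sq/sk are the query/key sequence lengths and attn_batches = np * b):

    def fused_softmax_shape_ok(sq: int, sk: int, b: int, np_: int) -> bool:
        # np_ is the number of attention heads; b is the micro-batch size.
        attn_batches = np_ * b
        return (
            16 < sk <= 16384            # raised from 4096 by this commit
            and sq % 4 == 0             # sq must be divisible by 4
            and sk % 4 == 0             # sk must be divisible by 4
            and attn_batches % 4 == 0   # np * b must be divisible by 4
        )

    assert fused_softmax_shape_ok(sq=2048, sk=8192, b=4, np_=16)       # newly eligible
    assert not fused_softmax_shape_ok(sq=2048, sk=32768, b=4, np_=16)  # still too long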

csrc/megatron/scaled_masked_softmax.h (9 additions, 1 deletion)
@@ -454,7 +454,7 @@ void dispatch_scaled_softmax_forward(
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 16384 );
     if (key_seq_len == 0) {
         return;
     } else {
@@ -530,6 +530,14 @@ void dispatch_scaled_softmax_forward(
             scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
             break;
+        case 13: // 8192
+            scaled_softmax_warp_forward<input_t, output_t, acc_t, 13>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+            break;
+        case 14: // 16384
+            scaled_softmax_warp_forward<input_t, output_t, acc_t, 14>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+            break;
         default:
             break;
     }
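
The switch above dispatches on the ceiling log2 of key_seq_len, with one template instantiation per power of two, so the two new cases extend coverage from 4096 (case 12) to 8192 and 16384. A small sketch of that mapping (the log2_ceil helper here is a stand-in for the one assumed to exist in the header; it is not shown in this diff):

    def log2_ceil(value: int) -> int:
        # Smallest n such that 2**n >= value.
        n = 0
        while (1 << n) < value:
            n += 1
        return n

    for sk in (4096, 4097, 8192, 8193, 16384):
        print(f"key_seq_len={sk} -> case {log2_ceil(sk)}")
    # 4096 -> 12, 4097 -> 13, 8192 -> 13, 8193 -> 14, 16384 -> 14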

csrc/megatron/scaled_masked_softmax_cuda.cu (1 addition, 1 deletion)
@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
   const int attn_heads = input.size(1);
   const int query_seq_len = input.size(2);
   const int key_seq_len = input.size(3);
-  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 16384);
   TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
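
The asserts above pin down the expected layout: a 4-D input of shape [batches, attn_heads, query_seq_len, key_seq_len] with a mask that broadcasts across heads. A hypothetical shape check in PyTorch (illustrative values only; this builds tensors of the accepted shape, it does not invoke the fused op itself):

    import torch

    b, heads, sq, sk = 2, 16, 1024, 16384   # sk may now be as large as 16384
    inp = torch.randn(b, heads, sq, sk, dtype=torch.float16, device="cuda")
    mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool, device="cuda")

    assert sk <= 16384          # new upper bound enforced by fwd_cuda
    assert sq > 1
    assert mask.size(1) == 1    # mask broadcasts over attention heads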

csrc/megatron/scaled_softmax_cuda.cu (1 addition, 1 deletion)
@@ -37,7 +37,7 @@ torch::Tensor fwd_cuda(
   const int attn_heads = input.size(1);
   const int query_seq_len = input.size(2);
   const int key_seq_len = input.size(3);
-  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 16384);
   TORCH_INTERNAL_ASSERT(query_seq_len > 1);

   // Output

csrc/megatron/scaled_upper_triang_masked_softmax.h (26 additions, 2 deletions)
@@ -340,7 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 16384 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -415,6 +415,18 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
             scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
             break;
+        case 12: // 4096
+            scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+            break;
+        case 13: // 8192
+            scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+            break;
+        case 14: // 16384
+            scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+            break;
         default:
             break;
     }
@@ -431,7 +443,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
-    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 16384 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -506,6 +518,18 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
             scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
             break;
+        case 12: // 4096
+            scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+            break;
+        case 13: // 8192
+            scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+            break;
+        case 14: // 16384
+            scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+            break;
         default:
             break;
     }
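
Note that the causal (upper-triangular) kernels previously capped softmax_elements at 2048, while the padding-mask kernels allowed 4096; after this commit both paths accept up to 16384, which is why this file gains three new cases rather than two. Since each case is a warp-level kernel over one row of next-power-of-two elements, the larger cases imply proportionally more work per warp lane. A rough back-of-the-envelope (assuming a 32-lane warp owns a full row, which matches the warp_forward/warp_backward naming but is not shown in this diff):

    WARP_SIZE = 32  # lanes per CUDA warp
    for log2_elements in (12, 13, 14):
        elements = 1 << log2_elements
        print(f"case {log2_elements}: {elements} elements, "
              f"{elements // WARP_SIZE} per lane")
    # case 12: 4096 elements, 128 per lane
    # case 13: 8192 elements, 256 per lane
    # case 14: 16384 elements, 512 per lane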

csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu (1 addition, 1 deletion)
@@ -35,7 +35,7 @@ torch::Tensor fwd_cuda(
   // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
   const int attn_batches = input.size(0);
   const int seq_len = input.size(1);
-  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(seq_len <= 16384);

   // Output
   auto act_options = input.options().requires_grad(false);
