Add a generic fused softmax (NVIDIA#1440)
* new kernel

Signed-off-by: Yi Dong <[email protected]>

* added the unit tests

Signed-off-by: Yi Dong <[email protected]>

* clean up unittest

Signed-off-by: Yi Dong <[email protected]>

* use float

Signed-off-by: Yi Dong <[email protected]>

* more clean up

Signed-off-by: Yi Dong <[email protected]>

* remove the long seq test case
yidong72 authored Aug 5, 2022
1 parent 31cbdd1 commit cd0a1f1
Showing 6 changed files with 731 additions and 0 deletions.
54 changes: 54 additions & 0 deletions apex/transformer/functional/fused_softmax.py
@@ -98,6 +98,33 @@ def scaled_masked_softmax(inputs, mask, scale):
        return ScaledMaskedSoftmax.apply(*args)


class GenericScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import generic_scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import generic_scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = generic_scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


def generic_scaled_masked_softmax(inputs, mask, scale):
    # input is 4D tensor (b, np, sq, sk)
    args = _cast_if_autocast_enabled(inputs, mask, scale)
    with torch.cuda.amp.autocast(enabled=False):
        return GenericScaledMaskedSoftmax.apply(*args)
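A minimal sketch of calling this functional wrapper directly, assuming the generic_scaled_masked_softmax_cuda extension has been built (its build wiring is part of this commit but not shown in this excerpt); the shapes, dtypes, and boolean padding-mask layout below are illustrative, not mandated by the diff:

import torch
from apex.transformer.functional.fused_softmax import generic_scaled_masked_softmax

# (b, np, sq, sk) attention scores on GPU; the extension accepts fp16/bf16 only
scores = torch.randn(2, 8, 1234, 1234, dtype=torch.float16, device="cuda")
# 4D padding mask; True marks positions to mask out (assumed convention)
mask = torch.zeros(2, 1, 1234, 1234, dtype=torch.bool, device="cuda")
probs = generic_scaled_masked_softmax(scores, mask, 1.0)  # scaling + mask + softmax in one fused call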


class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
@@ -209,3 +236,30 @@ def get_batch_per_block(sq, sk, b, np):
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)

class GenericFusedScaleMaskSoftmax(FusedScaleMaskSoftmax):
    """
    Generic version of FusedScaleMaskSoftmax.
    It removes the seq-len limitations and has a slight performance degradation compared with FusedScaleMaskSoftmax.
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self, input_in_fp16, input_in_bf16, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale,
    ):
        super().__init__(input_in_fp16, input_in_bf16, AttnMaskType.padding, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale)
        self.scaled_masked_softmax_fusion = generic_scaled_masked_softmax

    def is_kernel_available(self, mask, b, np, sq, sk):
        if self.scaled_masked_softmax_fusion and 0 < sk:  # user wants to fuse and sk must be >= 1
            return True
        return False
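For reference, a hedged usage sketch of the new module; mask_func below is a hypothetical padding-mask function (the real one is supplied by the caller) and all shapes are illustrative. The subclass always uses AttnMaskType.padding and only requires sk > 0, so arbitrary sequence lengths work:

import torch
from apex.transformer.functional.fused_softmax import GenericFusedScaleMaskSoftmax

def mask_func(scores, mask):
    # hypothetical: fill masked (True) positions with a large negative value before softmax
    return scores.masked_fill(mask, -10000.0)

softmax = GenericFusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    scaled_masked_softmax_fusion=True,   # request the fused CUDA path
    mask_func=mask_func,
    softmax_in_fp32=False,
    scale=None,
)

scores = torch.randn(2, 8, 3000, 3000, dtype=torch.float16, device="cuda")  # no power-of-two seq-len limit
mask = torch.zeros(2, 1, 3000, 3000, dtype=torch.bool, device="cuda")       # assumed boolean padding mask
probs = softmax(scores, mask)  # (b, np, sq, sk) softmax probabilities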
83 changes: 83 additions & 0 deletions csrc/megatron/generic_scaled_masked_softmax.cpp
@@ -0,0 +1,83 @@
/* coding=utf-8
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn
{
namespace fused_softmax
{
namespace generic_scaled_masked_softmax
{

torch::Tensor fwd_cuda(
    torch::Tensor const &input,
    torch::Tensor const &mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const &output_grads,
    torch::Tensor const &softmax_results,
    float scale_factor);

torch::Tensor fwd(
    torch::Tensor const &input,
    torch::Tensor const &mask,
    float scale_factor)
{
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const &output_grads,
    torch::Tensor const &softmax_results,
    float scale_factor)
{
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace generic_scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward",
        &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
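The bindings above only declare fwd_cuda and bwd_cuda; the CUDA kernels themselves live in a companion .cu source added by the same commit but not shown in this excerpt. As a hedged sketch (the .cu path and the boolean mask layout are assumptions), the extension could be JIT-built and exercised directly like this:

import torch
from torch.utils.cpp_extension import load

# JIT-compile the binding together with its (assumed) companion CUDA kernel file.
generic_scaled_masked_softmax_cuda = load(
    name="generic_scaled_masked_softmax_cuda",
    sources=[
        "csrc/megatron/generic_scaled_masked_softmax.cpp",
        "csrc/megatron/generic_scaled_masked_softmax_cuda.cu",  # hypothetical path to the kernel source
    ],
    verbose=True,
)

scores = torch.randn(2, 8, 128, 128, dtype=torch.float16, device="cuda")
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")  # assumed boolean padding mask
probs = generic_scaled_masked_softmax_cuda.forward(scores, mask, 1.0)
grads = generic_scaled_masked_softmax_cuda.backward(torch.ones_like(probs), probs, 1.0)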