[kernel] fixed repeated loading of kernels (hpcaitech#2549)
* [kernel] fixed repeated loading of kernels

* polish code

* polish code
FrankLeeeee authored Feb 3, 2023
1 parent 8438c35 commit dd14783
Showing 4 changed files with 59 additions and 46 deletions.
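
Every change below follows the same pattern: instead of re-running a try/except import on every forward/backward call and, when no pre-built binary exists, re-invoking the op builder's load() each time, the extension handle is now imported once at module level (falling back to None) and built lazily, at most once, on the first forward call. A minimal sketch of that pattern, with a hypothetical MyKernelBuilder and hypothetical import paths standing in for the real op builders:

import importlib

my_kernel = None  # module-level cache for the compiled extension


def _get_kernel():
    # Import or build the extension at most once per process, then reuse the handle.
    global my_kernel
    if my_kernel is None:
        try:
            my_kernel = importlib.import_module("mypackage._C.my_kernel")  # hypothetical pre-built path
        except ImportError:
            from mypackage.op_builder import MyKernelBuilder  # hypothetical builder
            my_kernel = MyKernelBuilder().load()  # expensive: JIT-compiles the extension
    return my_kernel

Deferring the build to first use keeps plain imports of the Python package cheap, while the module-level cache guarantees the expensive JIT compilation happens only once per process.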
colossalai/kernel/cuda_native/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -1,3 +1,5 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax
+from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+
+__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
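
With the re-export and __all__ in place, all four fused ops can be imported directly from the package. A quick import check (assumes ColossalAI is installed; it should not trigger a kernel build by itself, since building is deferred to the first forward call):

from colossalai.kernel.cuda_native import (
    FusedScaleMaskSoftmax,
    LayerNorm,
    MultiHeadAttention,
    ScaledUpperTriangMaskedSoftmax,
)

print(ScaledUpperTriangMaskedSoftmax)  # a torch.autograd.Function subclass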
colossalai/kernel/cuda_native/layer_norm.py (25 changes: 13 additions & 12 deletions)
@@ -9,37 +9,38 @@
 from torch.nn import init
 from torch.nn.parameter import Parameter
 
+from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
+
+try:
+    from colossalai._C import layer_norm
+except ImportError:
+    layer_norm = None
+
 
 class FusedLayerNormAffineFunction(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
-        try:
-            from colossalai._C import layer_norm
-        except ImportError:
-            from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
-            layer_norm = LayerNormBuilder().load()
-
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
+
+        global layer_norm
+        if layer_norm is None:
+            layer_norm = LayerNormBuilder().load()
         output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+        ctx.layernorm_op = layer_norm
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
 
         return output
 
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        try:
-            from colossalai._C import layer_norm
-        except ImportError:
-            from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
-            layer_norm = LayerNormBuilder().load()
-
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
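
For reference, the fused layer norm is used like torch.nn.LayerNorm; with this change the layer_norm extension is imported or JIT-built once on the first forward call, and the cached handle is reused afterwards. A usage sketch, assuming MixedFusedLayerNorm mirrors torch.nn.LayerNorm's constructor and expects CUDA tensors:

import torch

from colossalai.kernel.cuda_native import LayerNorm  # MixedFusedLayerNorm

# Assumption: the constructor mirrors torch.nn.LayerNorm(normalized_shape, eps).
norm = LayerNorm(1024, eps=1e-5).cuda().half()
x = torch.randn(8, 256, 1024, dtype=torch.float16, device="cuda")

y1 = norm(x)  # first call: the layer_norm extension is imported or built here
y2 = norm(x)  # subsequent calls reuse the cached module-level handle
assert y1.shape == x.shape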
colossalai/kernel/cuda_native/scaled_softmax.py (47 changes: 19 additions & 28 deletions)
@@ -1,11 +1,17 @@
"""This code from NVIDIA Megatron
with some changes. """

import enum

import torch
import torch.nn as nn

from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder

try:
from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
except ImportError:
scaled_masked_softmax = None
scaled_upper_triang_masked_softmax = None


class AttnMaskType(enum.Enum):
padding = 1
@@ -23,7 +29,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):

     @staticmethod
     def forward(ctx, inputs, scale):
-        from colossalai.kernel import scaled_upper_triang_masked_softmax
+        global scaled_upper_triang_masked_softmax
+        if scaled_upper_triang_masked_softmax:
+            scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load()
 
         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
@@ -33,8 +41,6 @@ def forward(ctx, inputs, scale):

     @staticmethod
     def backward(ctx, output_grads):
-        from colossalai.kernel import scaled_upper_triang_masked_softmax
-
         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
@@ -52,30 +58,23 @@ class ScaledMaskedSoftmax(torch.autograd.Function):

     @staticmethod
     def forward(ctx, inputs, mask, scale):
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
         scale_t = torch.tensor([scale])
 
+        # build and load kernel if not pre-built
+        global scaled_masked_softmax
+        if scaled_masked_softmax is None:
+            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
+
         softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
         softmax_results, scale_t = ctx.saved_tensors
 
         input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
-        return input_grads, None, None
+        return input_grads, None, None, None
 
 
 class FusedScaleMaskSoftmax(nn.Module):
@@ -113,14 +112,6 @@ def __init__(
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
-
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-        self.scaled_masked_softmax = scaled_masked_softmax
-
         assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
@@ -186,4 +177,4 @@ def forward_torch_softmax(self, input, mask):
         return probs
 
     def get_batch_per_block(self, sq, sk, b, np):
-        return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
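
When checking the fused path (ScaledUpperTriangMaskedSoftmax.apply(inputs, scale)) against a reference, the unfused computation is just scale, causal mask, softmax. A pure-PyTorch reference sketch; the (batches, sq, sk) layout and the -10000.0 mask fill are Megatron-style assumptions, not taken from this diff:

import torch


def scaled_upper_triang_masked_softmax_ref(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    # Scale the attention scores, mask out the strictly upper triangle
    # (future positions), then softmax over the last dimension.
    sq, sk = inputs.shape[-2], inputs.shape[-1]
    causal_mask = torch.triu(torch.ones(sq, sk, device=inputs.device), diagonal=1).bool()
    scores = (inputs * scale).masked_fill(causal_mask, -10000.0)
    return torch.softmax(scores, dim=-1)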
op_builder/builder.py (29 changes: 24 additions & 5 deletions)
@@ -6,6 +6,23 @@
 from typing import List
 
 
+def print_rank_0(message):
+    """
+    Print on only one process to avoid spamming.
+    """
+    try:
+        import torch.distributed as dist
+        if not dist.is_initialized():
+            is_main_rank = True
+        else:
+            is_main_rank = dist.get_rank() == 0
+    except ImportError:
+        is_main_rank = True
+
+    if is_main_rank:
+        print(message)
+
+
 class Builder(ABC):
     """
     Builder is the base class to build extensions for PyTorch.
@@ -117,7 +134,7 @@ def load(self, verbose=True):
         try:
             op_module = self.import_op()
             if verbose:
-                print(f"OP {self.prebuilt_import_path} already exists, skip building.")
+                print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
         except ImportError:
             # construct the build directory
             import torch
@@ -130,9 +147,11 @@
             Path(build_directory).mkdir(parents=True, exist_ok=True)
 
             if verbose:
-                print("=========================================================================================")
-                print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
-                print("=========================================================================================")
+                print_rank_0(
+                    "=========================================================================================")
+                print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
+                print_rank_0(
+                    "=========================================================================================")
 
             # load the kernel
             op_module = load(name=self.name,
@@ -146,7 +165,7 @@

         build_duration = time.time() - start_build
         if verbose:
-            print(f"Time to load {self.name} op: {build_duration} seconds")
+            print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
 
         return op_module
 
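
Outside the Builder class, the same prefer-pre-built-then-JIT flow that load() implements can be written in a few lines. This standalone sketch mirrors the method shown above; the function name, source list, and build directory are caller-supplied assumptions rather than ColossalAI API:

import importlib
import time
from pathlib import Path

from torch.utils.cpp_extension import load


def jit_load_kernel(name, sources, build_directory, prebuilt_import_path):
    # Return the pre-built extension if it can be imported, otherwise JIT-build it once.
    try:
        op_module = importlib.import_module(prebuilt_import_path)  # e.g. "colossalai._C.layer_norm"
        print(f"OP {prebuilt_import_path} already exists, skip building.")
    except ImportError:
        Path(build_directory).mkdir(parents=True, exist_ok=True)
        start_build = time.time()
        op_module = load(name=name, sources=sources, build_directory=build_directory, verbose=False)
        print(f"Time to load {name} op: {time.time() - start_build} seconds")
    return op_module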
