forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparse_attn.py
82 lines (64 loc) · 2.92 KB
/
sparse_attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import OpBuilder
try:
from packaging import version as pkg_version
except ImportError:
pkg_version = None
class SparseAttnBuilder(OpBuilder):
    """Op builder for the ``sparse_attn`` extension.

    Declares the op's name, its C++ source files and compiler flags, and
    checks whether the installed torch, CUDA and triton versions are ones
    the op is known to work with.
    """

    # NOTE(review): presumably the environment variable that toggles building
    # this op; it is not read anywhere in this class -- confirm in OpBuilder.
    BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
    NAME = "sparse_attn"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        """Return the dotted module path used for this op."""
        return f'deepspeed.ops.sparse_attention.{self.NAME}_op'

    def sources(self):
        """Return the list of C++ source files for this op."""
        return ['csrc/sparse_attention/utils.cpp']

    def cxx_args(self):
        """Return the C++ compiler flags: ``-O2`` and OpenMP support."""
        return ['-O2', '-fopenmp']

    def is_compatible(self, verbose=True):
        """Report whether this op can be used in the current environment.

        The op is rejected when running on ROCm PyTorch, when torch or triton
        cannot be imported, when triton is not exactly 1.0.0, when torch's
        CUDA version is missing or older than 10.1, when torch is not a 1.x
        release >= 1.5, or when the parent class reports incompatibility.

        Args:
            verbose: When true, emit a warning describing each failed check.
                When false, the checks run silently. The flag is also
                forwarded to the parent class's ``is_compatible``.

        Returns:
            bool: ``True`` only if every check passes.
        """

        def warn(message):
            # Every diagnostic goes through here so that ``verbose=False``
            # really does silence this method.
            if verbose:
                self.warning(message)

        # NOTE: a check that the llvm-config and cmake commands exist was
        # commented out in the original code and remains disabled.

        if self.is_rocm_pytorch():
            warn(f'{self.NAME} is not compatible with ROCM')
            return False

        try:
            import torch
        except ImportError:
            warn("unable to import torch, please install it first")
            return False

        # torch-cpu will not have a cuda version
        cuda_version = torch.version.cuda
        if cuda_version is None:
            cuda_compatible = False
            warn(f"{self.NAME} cuda is not available from torch")
        else:
            cuda_major, cuda_minor = cuda_version.split('.')[:2]
            cuda_major = int(cuda_major)
            # The minor component only matters (and is only parsed) for 10.x.
            cuda_compatible = (cuda_major == 10 and int(cuda_minor) >= 1) or cuda_major >= 11
            if not cuda_compatible:
                warn(f"{self.NAME} requires CUDA version 10.1+")

        # An unsupported CUDA or torch version does not return early: the
        # triton checks below still run so their warnings are reported too.
        version_parts = torch.__version__.split('.')
        torch_major = int(version_parts[0])
        torch_minor = int(version_parts[1])
        torch_compatible = torch_major == 1 and torch_minor >= 5
        if not torch_compatible:
            warn(f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {torch_major}.{torch_minor}')

        try:
            import triton
        except ImportError:
            # auto-install of triton is broken on some systems, reverting to manual install for now
            # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710
            warn("please install triton==1.0.0 if you want to use sparse attention")
            return False

        if pkg_version:
            installed_triton = pkg_version.parse(triton.__version__)
            triton_mismatch = installed_triton != pkg_version.parse("1.0.0")
        else:
            # ``packaging`` is unavailable: fall back to an exact string match.
            installed_triton = triton.__version__
            triton_mismatch = installed_triton != "1.0.0"

        if triton_mismatch:
            warn(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible")
            return False

        return super().is_compatible(verbose) and torch_compatible and cuda_compatible