forked from XuezheMax/megalodon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
65 lines (54 loc) · 1.94 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Absolute directory containing this setup script. Source and include paths
# are anchored here so the build does not depend on the current directory.
PATH = os.path.dirname(os.path.abspath(__file__))

# Operators under megalodon/csrc/ops/. Each one contributes a host-side
# "<op>.cc" file followed by a CUDA "<op>_kernel.cu" file. To add an op,
# add its name here.
_OPS = (
    "attention",
    "attention_softmax",
    "ema_hidden",
    "ema_parameters",
    "fftconv",
    "sequence_norm",
    "timestep_norm",
)

# All translation units compiled into the single `megalodon_extension` module.
CSRCS = [
    os.path.join(PATH, "megalodon/csrc/blas.cc"),
    os.path.join(PATH, "megalodon/csrc/megalodon_extension.cc"),
]
for _op in _OPS:
    CSRCS.extend((
        os.path.join(PATH, f"megalodon/csrc/ops/{_op}.cc"),
        os.path.join(PATH, f"megalodon/csrc/ops/{_op}_kernel.cu"),
    ))

INCLUDE_DIRS = [
    os.path.join(PATH, "megalodon/csrc"),
]

# Host-compiler flags. main() also prepends these to the nvcc flags.
CXX_FLAGS = [
    "-O3",
    "-std=c++17",
]

# Value passed to nvcc's `--threads` option (the parallelism nvcc may use for
# its compilation steps). Defaults to "4"; set the NVCC_THREADS environment
# variable to override it. An empty value falls back to the default.
NVCC_THREADS = os.environ.get("NVCC_THREADS") or "4"

NVCC_FLAGS = [
    # Optional flags, currently disabled:
    # "-rdc=true",
    # "--use_fast_math",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "--threads",
    NVCC_THREADS,
]
def main():
    """Build and install megalodon together with its CUDA extension.

    Compiles the C++/CUDA sources listed in ``CSRCS`` into one top-level
    extension module named ``megalodon_extension``, using PyTorch's
    ``CUDAExtension``/``BuildExtension`` helpers. It also installs any Python
    packages found under ``megalodon/``, meaning directories that contain an
    ``__init__.py``.
    """
    setup(
        name="megalodon",
        version="0.0.1",
        # When ``ext_modules`` is passed from setup.py, setuptools skips its
        # automatic package discovery. Without an explicit ``packages`` list,
        # a regular (non-editable) install would ship only the compiled
        # extension and no ``megalodon`` Python package.
        packages=find_packages(include=["megalodon", "megalodon.*"]),
        ext_modules=[
            CUDAExtension(
                "megalodon_extension",
                CSRCS,
                include_dirs=INCLUDE_DIRS,
                extra_compile_args={
                    "cxx": CXX_FLAGS,
                    # nvcc also receives the host flags (-O3, -std=c++17).
                    "nvcc": CXX_FLAGS + NVCC_FLAGS,
                },
            )
        ],
        cmdclass={"build_ext": BuildExtension},
    )


# Run the build only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()