forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pack_featgraph.py
40 lines (31 loc) · 976 Bytes
/
pack_featgraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
""" Export featgraph kernels to a shared library. """
import tvm
from sddmm import sddmm_tree_reduction_gpu
def get_sddmm_kernels_gpu(idtypes, dtypes):
    """Build one GPU SDDMM tree-reduction kernel per (dtype, idtype) pair.

    Parameters
    ----------
    idtypes: Iterable[str]
        Possible index types (e.g. ``'int32'``, ``'int64'``). Any iterable is
        accepted; it is materialised once internally.
    dtypes: Iterable[str]
        Possible data types.

    Returns
    -------
    List[IRModule]:
        The list of IRModules, one per combination, ordered with ``dtypes`` as
        the outer loop and ``idtypes`` as the inner loop.
    """
    # ``idtypes`` is traversed once per data type. Materialise it so that a
    # one-shot iterable (generator, ``map`` object, ...) is not exhausted
    # after the first data type, which would silently drop kernels.
    idtypes = list(idtypes)
    # SDDMM Tree Reduction
    return [
        sddmm_tree_reduction_gpu(idtype, dtype)
        for dtype in dtypes
        for idtype in idtypes
    ]
if __name__ == '__main__':
    # A separate kernel is generated for every (value type, index type) pair
    # drawn from these two lists.
    index_types = ['int32', 'int64']
    value_types = ['float16', 'float64', 'float32', 'int32', 'int64']
    output_path = 'libfeatgraph_kernels.so'

    # Compile all SDDMM kernels together for the CUDA target (with an LLVM
    # host target) and write the resulting module out as a shared library.
    all_kernels = list(get_sddmm_kernels_gpu(index_types, value_types))
    built_module = tvm.build(all_kernels, target='cuda', target_host='llvm')
    built_module.export_library(output_path)