-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeatgraph.cc
81 lines (71 loc) · 2.2 KB
/
featgraph.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/**
* Copyright (c) 2020 by Contributors
* @file featgraph/src/featgraph.cc
* @brief FeatGraph kernels.
*/
#include <dmlc/logging.h>
#include <featgraph.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
namespace dgl {
namespace featgraph {
/* @brief Singleton that loads the featgraph module. */
class FeatGraphModule {
public:
static FeatGraphModule* Global() {
static FeatGraphModule inst;
return &inst;
}
void Load(const std::string& path) {
mod = tvm::runtime::Module::LoadFromFile(path);
}
inline tvm::runtime::ModuleNode* Get() {
auto ret = mod.operator->();
if (!ret) {
LOG(FATAL) << "FeatGraph module have not been loaded. "
<< "Please set path of featgraph shared library.";
}
return ret;
}
private:
tvm::runtime::Module mod;
FeatGraphModule() {}
};
/* @brief Load Featgraph module from given path. */
void LoadFeatGraphModule(const std::string& path) {
FeatGraphModule::Global()->Load(path);
}
/**
 * @brief Convert DLDataType to string.
 *
 * The result is a prefix selected by the type code followed by the bit width
 * in decimal, e.g. code 2 with 32 bits yields "float32". Only the `code` and
 * `bits` fields are read.
 *
 * @param t Data type descriptor to convert.
 * @return The textual type name. An unrecognized code is reported through
 *         LOG(FATAL); an empty string is returned only if that call returns.
 */
inline std::string DTypeAsStr(const DLDataType& t) {
  const std::string bits = std::to_string(t.bits);
  switch (t.code) {
    case 0U:
      return "int" + bits;
    case 1U:
      return "uint" + bits;
    case 2U:
      return "float" + bits;
    case 3U:
      // NOTE(review): recent DLPack headers use 4 for bfloat and 3 for an
      // opaque handle -- confirm this mapping against the DLPack version and
      // the kernel names in the FeatGraph library before changing it.
      return "bfloat" + bits;
    default:
      // DLPack declares `code` as uint8_t, which an ostream prints as a raw
      // character; widen it so the message shows the numeric type code.
      LOG(FATAL) << "Type code " << static_cast<int>(t.code)
                 << " not recognized";
  }
  // Not expected to be reached; ensures every control path returns a value.
  return std::string();
}
/**
 * @brief Build the name of an operator specialized on two data types.
 *
 * The result has the form "<base_name>_<dtype>_<idtype>", where each type is
 * rendered by DTypeAsStr.
 *
 * @param base_name Operator name without type suffixes.
 * @param dtype Data type rendered as the first suffix.
 * @param idtype Data type rendered as the second suffix.
 * @return The composed operator name.
 */
inline std::string GetOperatorName(
    const std::string& base_name, const DLDataType& dtype,
    const DLDataType& idtype) {
  std::string name = base_name;
  name += "_";
  name += DTypeAsStr(dtype);
  name += "_";
  name += DTypeAsStr(idtype);
  return name;
}
/**
 * @brief Call FeatGraph's SDDMM kernel.
 *
 * Looks up a function in the loaded FeatGraph module whose name is
 * "SDDMMTreeReduction" suffixed with the data type of `row` and then the data
 * type of `lhs`, and invokes it with all five tensors in the order given.
 *
 * @param row Tensor passed as the first kernel argument; its dtype forms the
 *            first name suffix.
 * @param col Tensor passed as the second kernel argument.
 * @param lhs Tensor passed as the third kernel argument; its dtype forms the
 *            second name suffix.
 * @param rhs Tensor passed as the fourth kernel argument.
 * @param out Tensor passed as the fifth kernel argument.
 *
 * @note Logs at FATAL level if the module has not been loaded or if it does
 *       not provide a kernel with the computed name.
 */
void SDDMMTreeReduction(
    DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
    DLManagedTensor* rhs, DLManagedTensor* out) {
  tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
  const std::string f_name = GetOperatorName(
      "SDDMMTreeReduction", (row->dl_tensor).dtype, (lhs->dl_tensor).dtype);
  tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
  if (f == nullptr) {
    // Returning quietly here would leave `out` unwritten with no indication
    // to the caller, so report the missing kernel instead.
    LOG(FATAL) << "FeatGraph kernel \"" << f_name
               << "\" not found in the loaded module.";
    return;
  }
  f(row, col, lhs, rhs, out);
}
} // namespace featgraph
} // namespace dgl