-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregistry.cc
161 lines (144 loc) · 4.52 KB
/
registry.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*!
* Copyright (c) 2017 by Contributors
* \file registry.cc
* \brief The global registry of packed functions.
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array>
#include <utility>
#include "runtime_base.h"
namespace dgl {
namespace runtime {
/*!
 * \brief Shared state behind the registry: the name -> entry table, the
 *  extension-type vtables, and the mutex used by the registry functions.
 *  A single instance is handed out by Global().
 */
struct Registry::Manager {
  // Table of registered functions, keyed by name.
  // Entries are deliberately held through raw pointers: a PackedFunc can
  // contain callbacks into the host language (Python), and such resources can
  // become invalid because the order of destruction is indeterminate.
  // The resources will only be recycled during program exit.
  std::unordered_map<std::string, Registry*> fmap;
  // vtable slots for extension types, indexed by type code.
  std::array<ExtTypeVTable, kExtEnd> ext_vtable;
  // Mutex locked by the registry functions that read or modify fmap and by
  // extension-type registration.
  std::mutex mutex;

  /*! \brief Start with every extension slot's destroy callback cleared. */
  Manager() {
    for (size_t i = 0; i < ext_vtable.size(); ++i) {
      ext_vtable[i].destroy = nullptr;
    }
  }

  /*! \brief Pointer to the function-local static instance. */
  static Manager* Global() {
    static Manager singleton;
    return &singleton;
  }
};
/*!
 * \brief Set the function body stored in this registry entry.
 * \param f The packed function to store. It is received by value, so the
 *  argument is already a private copy owned by this call.
 * \return Reference to this entry, so further calls can be chained.
 */
Registry& Registry::set_body(PackedFunc f) {  // NOLINT(*)
  // Move the by-value parameter into place instead of copying it a second time.
  func_ = std::move(f);
  return *this;
}
/*!
 * \brief Obtain the registry entry for a name, creating it when absent.
 * \param name The global function name.
 * \param override Whether an already-registered name is acceptable. When the
 *  name exists and this is false, the CHECK below fails.
 * \return Reference to the (new or existing) entry for name.
 */
Registry& Registry::Register(const std::string& name, bool override) {  // NOLINT(*)
  Manager* mgr = Manager::Global();
  std::lock_guard<std::mutex> guard(mgr->mutex);
  auto found = mgr->fmap.find(name);
  if (found != mgr->fmap.end()) {
    // Name already present: only allowed when the caller asked to override.
    CHECK(override)
      << "Global PackedFunc " << name << " is already registered";
    return *found->second;
  }
  // First registration under this name: allocate the entry and record it.
  // The entry is intentionally never deleted by the registry.
  Registry* entry = new Registry();
  entry->name_ = name;
  mgr->fmap[name] = entry;
  return *entry;
}
bool Registry::Remove(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) return false;
m->fmap.erase(it);
return true;
}
/*!
 * \brief Look up a registered function by name.
 * \param name The global function name.
 * \return Pointer to the function stored in the matching entry, or nullptr
 *  when no entry with that name exists.
 */
const PackedFunc* Registry::Get(const std::string& name) {
  Manager* mgr = Manager::Global();
  std::lock_guard<std::mutex> guard(mgr->mutex);
  auto found = mgr->fmap.find(name);
  return found != mgr->fmap.end() ? &(found->second->func_) : nullptr;
}
std::vector<std::string> Registry::ListNames() {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
std::vector<std::string> keys;
keys.reserve(m->fmap.size());
for (const auto &kv : m->fmap) {
keys.push_back(kv.first);
}
return keys;
}
/*!
 * \brief Fetch the vtable registered for an extension type code.
 * \param type_code The extension type code; must lie strictly between
 *  kExtBegin and kExtEnd.
 * \return Pointer to the vtable slot for type_code.
 * \note Both CHECKs fail if the code is out of range or the slot's destroy
 *  callback is still null (i.e. nothing was registered for it).
 * \note NOTE(review): this read does not take the manager mutex, whereas
 *  RegisterInternal writes the same slots under it.
 */
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd);
  ExtTypeVTable* vt = &(Registry::Manager::Global()->ext_vtable[type_code]);
  CHECK(vt->destroy != nullptr)
    << "Extension type not registered";
  return vt;
}
/*!
 * \brief Install a vtable for an extension type code.
 * \param type_code The extension type code; must lie strictly between
 *  kExtBegin and kExtEnd.
 * \param vt The vtable to copy into the slot for type_code; any previous
 *  content of that slot is overwritten.
 * \return Pointer to the slot that now holds the copy.
 */
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
    int type_code, const ExtTypeVTable& vt) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd);
  Registry::Manager* mgr = Registry::Manager::Global();
  std::lock_guard<std::mutex> guard(mgr->mutex);
  ExtTypeVTable& slot = mgr->ext_vtable[type_code];
  slot = vt;
  return &slot;
}
} // namespace runtime
} // namespace dgl
/*!
 * \brief Entry that holds values handed back to callers of the C API below.
 *  It is accessed through DGLFuncThreadLocalStore.
 */
struct DGLFuncThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers (filled with c_str() of ret_vec_str elements) */
std::vector<const char *> ret_vec_charp;
};
/*! \brief Thread-local store used to keep return values alive for the caller. */
using DGLFuncThreadLocalStore = dmlc::ThreadLocalStore<DGLFuncThreadLocalEntry>;
int DGLExtTypeFree(void* handle, int type_code) {
API_BEGIN();
dgl::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
API_END();
}
int DGLFuncRegisterGlobal(
const char* name, DGLFunctionHandle f, int override) {
API_BEGIN();
dgl::runtime::Registry::Register(name, override != 0)
.set_body(*static_cast<dgl::runtime::PackedFunc*>(f));
API_END();
}
int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
API_BEGIN();
const dgl::runtime::PackedFunc* fp =
dgl::runtime::Registry::Get(name);
if (fp != nullptr) {
*out = new dgl::runtime::PackedFunc(*fp); // NOLINT(*)
} else {
*out = nullptr;
}
API_END();
}
int DGLFuncListGlobalNames(int *out_size,
const char*** out_array) {
API_BEGIN();
DGLFuncThreadLocalEntry *ret = DGLFuncThreadLocalStore::Get();
ret->ret_vec_str = dgl::runtime::Registry::ListNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}