-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensordispatch.cc
68 lines (55 loc) · 1.6 KB
/
tensordispatch.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/*!
* Copyright (c) 2019 by Contributors
* \file runtime/tensordispatch.cc
* \brief Adapter library caller
*/
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/packed_func_ext.h>
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#else // !WIN32
#include <dlfcn.h>
#endif // WIN32
#include <cstring>
namespace dgl {
namespace runtime {

// Out-of-line definition of the static constexpr array of adapter symbol names.
// Before C++17 this definition is required because names_ is odr-used
// (indexed with a runtime index) in Load() below.
constexpr const char *TensorDispatcher::names_[];

/*!
 * \brief Load the tensor adapter library and resolve its entry points.
 *
 * Each symbol listed in names_ is looked up in the library and stored in the
 * matching slot of entrypoints_. On success available_ is set to true.
 *
 * \param path Filesystem path of the adapter library. A null or empty path
 *        means no adapter is configured.
 * \return true if the library was opened and every symbol was resolved;
 *         false if no path was given or the library could not be opened.
 * \note Fails fatally (CHECK) if called after a successful load, or if the
 *       library lacks one of the symbols in names_. In the latter case the
 *       library is unloaded first, so a later Load() call does not leak it.
 */
bool TensorDispatcher::Load(const char *path) {
  CHECK(!available_) << "The tensor adapter can only load once.";

  if (path == nullptr || *path == '\0') {
    // No dispatcher library; all operators fall back to DGL's implementation.
    return false;
  }

#if defined(WIN32) || defined(_WIN32)
  // `path` is a narrow string, so call the ANSI entry point explicitly: plain
  // LoadLibrary maps to LoadLibraryW (wide strings) when UNICODE is defined.
  handle_ = LoadLibraryA(path);
  if (!handle_) {
    DLOG(WARNING) << "TensorDispatcher: LoadLibrary failed with error code "
                  << GetLastError();
    return false;
  }
  for (int i = 0; i < num_entries_; ++i) {
    entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i]));
    if (!entrypoints_[i]) {
      // Unload before the fatal check so that handle_ is not left dangling
      // (and later overwritten/leaked) if the failure is caught and Load() retried.
      FreeLibrary(handle_);
      handle_ = nullptr;
      CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
    }
  }
#else  // !WIN32
  handle_ = dlopen(path, RTLD_LAZY);
  if (!handle_) {
    DLOG(WARNING) << "TensorDispatcher: dlopen failed: " << dlerror();
    return false;
  }
  for (int i = 0; i < num_entries_; ++i) {
    entrypoints_[i] = dlsym(handle_, names_[i]);
    if (!entrypoints_[i]) {
      // Unload before the fatal check so that handle_ is not left dangling
      // (and later overwritten/leaked) if the failure is caught and Load() retried.
      dlclose(handle_);
      handle_ = nullptr;
      CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
    }
  }
#endif  // WIN32

  available_ = true;
  return true;
}

/*!
 * \brief Unload the adapter library if Load() opened one.
 */
TensorDispatcher::~TensorDispatcher() {
  if (handle_) {
#if defined(WIN32) || defined(_WIN32)
    FreeLibrary(handle_);
#else  // !WIN32
    dlclose(handle_);
#endif  // WIN32
  }
}

}  // namespace runtime
}  // namespace dgl