forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththread_storage_scope.h
207 lines (196 loc) · 5.65 KB
/
thread_storage_scope.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
/*!
* Copyright (c) 2017 by Contributors
* \file thread_storage_scope.h
* \brief Extract thread axis configuration from DGLArgs.
*/
#ifndef DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <dgl/runtime/packed_func.h>
#include <string>
#include <vector>
namespace dgl {
namespace runtime {
/*!
 * \brief Memory hierarchy rank in the storage system.
 * \note The global rank and shared rank have one to one
 *  correspondence to the thread rank.
 */
enum class StorageRank {
/*! \brief global memory */
kGlobal = 0,
/*! \brief shared memory among thread group */
kShared = 1,
/*!
 * \brief reserved for warp memory.
 *  This is only used by the programming model.
 *  There is usually no such memory in a GPU.
 *  Instead, it can be simulated with registers and shuffle.
 */
kWarp = 2,
/*! \brief thread local memory */
kLocal = 3
};
/*!
 * \brief Pick the storage rank used by default for a given thread scope rank.
 * \param thread_scope_rank The thread scope rank; -1, 0 and 1 are recognized.
 * \return kGlobal for -1, kShared for 0 and kLocal for 1.
 * \note Any other rank is reported through LOG(FATAL).
 */
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
  if (thread_scope_rank == -1) {
    return StorageRank::kGlobal;
  }
  if (thread_scope_rank == 0) {
    return StorageRank::kShared;
  }
  if (thread_scope_rank == 1) {
    return StorageRank::kLocal;
  }
  LOG(FATAL) << "unknown rank";
  // Fallback value, only relevant if the fatal log statement returns.
  return StorageRank::kGlobal;
}
/*! \brief class to represent storage scope */
struct StorageScope {
  /*! \brief The rank of the storage */
  StorageRank rank{StorageRank::kGlobal};
  /*! \brief tag for special purpose memory. */
  std::string tag;

  /*! \brief Two scopes are equal when both their rank and their tag match. */
  bool operator==(const StorageScope& other) const {
    return rank == other.rank && tag == other.tag;
  }
  /*! \brief Negation of operator==. */
  bool operator!=(const StorageScope& other) const {
    return !(*this == other);
  }

  /*!
   * \brief Render the scope as text.
   * \return The rank name ("global", "shared", "warp" or "local")
   *  immediately followed by the tag, with no separator.
   * \note An unrecognized rank value is reported through LOG(FATAL).
   */
  std::string to_string() const {
    switch (rank) {
      case StorageRank::kGlobal: return "global" + tag;
      case StorageRank::kShared: return "shared" + tag;
      case StorageRank::kWarp: return "warp" + tag;
      case StorageRank::kLocal: return "local" + tag;
      default: LOG(FATAL) << "unknown storage scope"; return "";
    }
  }

  /*!
   * \brief make storage scope from string
   *
   *  The string must start with one of the rank names "global", "shared",
   *  "warp" or "local"; everything after that prefix is stored as the tag.
   *
   * \param s The string to be parsed.
   * \return The storage scope.
   * \note A string with none of the known prefixes is reported through
   *  LOG(FATAL).
   */
  static StorageScope make(const std::string& s) {
    StorageScope r;
    if (s.compare(0, 6, "global") == 0) {
      r.rank = StorageRank::kGlobal;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 6, "shared") == 0) {
      r.rank = StorageRank::kShared;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 4, "warp") == 0) {
      r.rank = StorageRank::kWarp;
      r.tag = s.substr(4, std::string::npos);
    } else if (s.compare(0, 5, "local") == 0) {
      r.rank = StorageRank::kLocal;
      r.tag = s.substr(5, std::string::npos);
    } else {
      LOG(FATAL) << "unknown storage scope " << s;
    }
    return r;
  }
};
/*! \brief class to represent thread scope */
struct ThreadScope {
  /*! \brief The rank of thread scope */
  int rank{0};
  /*! \brief the dimension index under the rank */
  int dim_index{0};

  /*!
   * \brief make thread scope from string
   *
   *  Accepted forms:
   *  - "vthread" or "cthread": rank 1, dim_index -1.
   *  - "blockIdx." followed by x, y or z: rank 0, dim_index 0, 1 or 2.
   *  - "threadIdx." followed by x, y or z: rank 1, dim_index 0, 1 or 2.
   *
   * \param s The string to be parsed.
   * \return The thread scope.
   * \note Any other string, including a prefix with a missing or invalid
   *  axis letter, is reported through LOG(FATAL). Characters after the
   *  axis letter are not inspected.
   */
  static ThreadScope make(const std::string& s) {
    // Translate the axis letter at position pos into 0/1/2. A missing letter
    // or one outside x..z is rejected here; otherwise it would turn into an
    // arbitrary dim_index that callers may use as an array index.
    const auto axis_index = [&s](std::string::size_type pos) -> int {
      const char axis = pos < s.size() ? s[pos] : '\0';
      if (axis < 'x' || axis > 'z') {
        LOG(FATAL) << "Unknown threadscope " << s;
      }
      return static_cast<int>(axis - 'x');
    };

    ThreadScope r;
    if (s == "vthread" || s == "cthread") {
      // virtual thread at the same level as local
      r.rank = 1;
      r.dim_index = -1;
    } else if (s.compare(0, 9, "blockIdx.") == 0) {
      r.rank = 0;
      r.dim_index = axis_index(9);
    } else if (s.compare(0, 10, "threadIdx.") == 0) {
      r.rank = 1;
      r.dim_index = axis_index(10);
    } else {
      LOG(FATAL) << "Unknown threadscope " << s;
    }
    return r;
  }
};
/*! \brief workload specification */
struct ThreadWorkLoad {
  // Flat array of six extents: slots 0-2 are served by grid_dim(),
  // slots 3-5 are served by block_dim().
  size_t work_size[6];

  /*!
   * \brief Look up one block extent.
   * \param i The block dimension.
   * \return i-th block dim, i.e. work_size[i + 3]
   */
  size_t block_dim(size_t i) const {
    const size_t* block_extents = work_size + 3;
    return block_extents[i];
  }

  /*!
   * \brief Look up one grid extent.
   * \param i The grid dimension.
   * \return i-th grid dim, i.e. work_size[i]
   */
  size_t grid_dim(size_t i) const {
    const size_t* grid_extents = work_size;
    return grid_extents[i];
  }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
public:
void Init(size_t base,
const std::vector<std::string>& thread_axis_tags) {
base_ = base;
std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
const std::string& tag = thread_axis_tags[i];
ThreadScope ts = ThreadScope::make(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
}
work_dim_ = 1;
for (int i = 0; i < 3; ++i) {
if (filled[i] || filled[i + 3]) {
work_dim_ = i + 1;
}
}
}
// extract workload from arguments.
ThreadWorkLoad Extract(DGLArgs x) const {
ThreadWorkLoad w;
std::fill(w.work_size, w.work_size + 6, 1);
for (size_t i = 0; i < arg_index_map_.size(); ++i) {
w.work_size[arg_index_map_[i]] =
static_cast<size_t>(x.values[base_ + i].v_int64);
}
return w;
}
// return the work dim
size_t work_dim() const {
return work_dim_;
}
private:
/*! \brief base axis */
size_t base_;
/*! \brief The worker dimension */
size_t work_dim_;
/*! \brief The index mapping. */
std::vector<uint32_t> arg_index_map_;
};
} // namespace runtime
} // namespace dgl
namespace std {
/*!
 * \brief Hash support so StorageScope can key unordered containers.
 *
 *  Both fields compared by StorageScope::operator== (rank and tag) feed the
 *  hash, so scopes that share a rank but differ in tag do not all collide.
 */
template <>
struct hash<::dgl::runtime::StorageScope> {
  /*!
   * \param k The storage scope to hash.
   * \return A hash value combining the rank and the tag of k.
   */
  std::size_t operator()(const ::dgl::runtime::StorageScope& k) const {
    const std::size_t rank_hash = static_cast<std::size_t>(k.rank);
    const std::size_t tag_hash = std::hash<std::string>()(k.tag);
    // Mix the two values so that neither field alone determines the result.
    return rank_hash ^
           (tag_hash + 0x9e3779b9 + (rank_hash << 6) + (rank_hash >> 2));
  }
};
} // namespace std
#endif // DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_