-
Notifications
You must be signed in to change notification settings - Fork 43
/
kernel.h
68 lines (57 loc) · 1.64 KB
/
kernel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#pragma once
#include "stats.h"
#include "clwrap.h"
#include "timeutil.h"
#include "common.h"
#include <string>
#include <vector>
#include <memory>
// Wraps a single OpenCL kernel object together with the queue it is enqueued
// on, its default number of work-groups, and optional per-launch timing stats.
class Kernel {
  Holder<cl_kernel> kernel;  // Kernel handle created by makeKernel() in the constructor.
  cl_queue queue;            // Queue passed to ::run() and finish().
  int workGroups;            // Default work-group count used by operator().
  std::string name;          // Kernel name; also forwarded to ::run().

  // Not read or written by any member function of this class. Zero-initialized
  // so the object never carries indeterminate values (e.g. when copied/moved).
  u64 timeSum{0};
  u64 nCalls{0};

  bool doTime;    // When true, every launch is bracketed by finish() and timed.
  int groupSize;  // Value returned by getWorkGroupSize() for this kernel/device.
  Stats stats;    // Accumulates the per-launch durations when doTime is set.

  // Binds a single argument at index `pos`.
  template<typename T> void setArgs(int pos, const T &arg) { ::setArg(kernel.get(), pos, arg); }

  // Binds `arg` at index `pos`, then the remaining arguments at the following
  // consecutive indices.
  template<typename T, typename... Args> void setArgs(int pos, const T &arg, const Args &...tail) {
    setArgs(pos, arg);
    setArgs(pos + 1, tail...);
  }

  // Enqueues the kernel once. The last numeric argument handed to ::run() is
  // nWorkGroups * groupSize (presumably the total work size -- confirm against
  // ::run's declaration).
  void launch(u32 nWorkGroups) {
    ::run(queue, kernel.get(), groupSize, nWorkGroups * groupSize, name);
  }

public:
  // Creates the kernel called `name` from `program` and queries its work-group
  // size for `device`.
  //
  // program    - program handed to makeKernel().
  // q          - queue used for all launches of this kernel.
  // device     - device handed to getWorkGroupSize().
  // workGroups - default number of work-groups used by operator().
  // name       - kernel name.
  // doTime     - enables the timed path in run().
  Kernel(cl_program program, cl_queue q, cl_device_id device, int workGroups, const std::string &name, bool doTime) :
    kernel(makeKernel(program, name.c_str())),
    queue(q),
    workGroups(workGroups),
    name(name),
    doTime(doTime),
    groupSize(getWorkGroupSize(kernel.get(), device, name.c_str()))
  {
    // assert((workSize % groupSize == 0) || (log("%s\n", name.c_str()), false));
  }

  // Binds `tail...` at consecutive argument indices starting at `pos`.
  // Takes non-const lvalue references, so temporaries are rejected at compile time.
  template<typename... Args> void setFixedArgs(int pos, Args &...tail) { setArgs(pos, tail...); }

  // Binds `args...` starting at argument index 0, then launches the kernel with
  // the default number of work-groups.
  template<typename... Args> void operator()(const Args &...args) {
    setArgs(0, args...);
    run(workGroups);
  }

  // Launches the kernel with `nWorkGroups` work-groups using whatever arguments
  // are currently bound. When doTime is set, the queue is drained with finish()
  // before and after the launch, and the value of timer.deltaMicros() taken
  // after the second finish() is added to the stats.
  void run(u32 nWorkGroups) {
    if (!doTime) {
      launch(nWorkGroups);
      return;
    }

    finish(queue);  // Keep previously queued work out of the measurement.
    Timer timer;
    launch(nWorkGroups);
    finish(queue);  // Wait for this launch to complete before reading the timer.
    stats.add(timer.deltaMicros());
  }

  // Returns a copy of the kernel name.
  std::string getName() const { return name; }

  // Returns the statistics gathered so far and resets the accumulator.
  StatsInfo resetStats() {
    StatsInfo ret = stats.getStats();
    stats.reset();
    return ret;
  }
};