THCScanUtils.cuh (forked from pytorch/pytorch)

#ifndef THC_SCAN_UTILS_INC
#define THC_SCAN_UTILS_INC

#include "THCAsmUtils.cuh"

// Collection of in-kernel scan / prefix sum utilities

// Inclusive prefix sum using shared memory
template <typename T, bool KillWARDependency>
__device__ void inclusivePrefixSum(T* smem, T in, T* out) {
  // FIXME: this is a slow, simple implementation; need up/down sweep,
  // prevent smem conflicts
  smem[threadIdx.x] = in;

  __syncthreads();

  for (int offset = 1; offset < blockDim.x; offset *= 2) {
    T val = 0;

    if (threadIdx.x >= offset) {
      val = smem[threadIdx.x - offset] + smem[threadIdx.x];
    }

    __syncthreads();

    if (threadIdx.x >= offset) {
      smem[threadIdx.x] = val;
    }

    __syncthreads();
  }

  *out = smem[threadIdx.x];

  // Prevent write-after-read dependencies on smem usage above if necessary
  if (KillWARDependency) {
    __syncthreads();
  }
}
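
// Usage sketch (illustrative, not part of the original header): a
// hypothetical kernel that scans one block-sized tile per block with
// inclusivePrefixSum. The kernel name and launch assumptions are made up for
// illustration: it expects dynamic shared memory of blockDim.x * sizeof(T)
// and exactly gridDim.x * blockDim.x input elements, since every thread of a
// block must take part in the scan.
template <typename T>
__global__ void blockInclusiveScanSketch(const T* in, T* out) {
  extern __shared__ unsigned char scanSmemRaw[];
  T* smem = reinterpret_cast<T*>(scanSmemRaw);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  T result;
  // All threads of the block participate; KillWARDependency is false since
  // shared memory is not reused afterwards
  inclusivePrefixSum<T, false>(smem, in[i], &result);

  out[i] = result;
}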

// Exclusive prefix sum using shared memory
template <typename T, bool KillWARDependency>
__device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
  // FIXME: crappy implementation
  // We kill write-after-read dependencies separately below, hence the `false`
  inclusivePrefixSum<T, false>(smem, in, out);

  // Inclusive to exclusive: subtract this thread's own contribution
  *out -= in;

  // After the inclusive scan, the last slot holds the block's total
  *carry = smem[blockDim.x - 1];

  // Prevent write-after-read dependencies on smem usage above if necessary
  if (KillWARDependency) {
    __syncthreads();
  }
}
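
// Worked example (illustrative): for a 4-thread block with per-thread inputs
// {1, 2, 3, 4}, the inclusive scan gives {1, 3, 6, 10}; exclusivePrefixSum
// then returns *out = {0, 1, 3, 6} across the threads and *carry = 10 (the
// block total) to every thread.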

// Inclusive prefix sum for binary vars using intra-warp voting +
// shared memory. `smem` must have room for at least one element per warp of
// the block.
template <typename T, bool KillWARDependency>
__device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
  // Within-warp, we use warp voting.
  // (__ballot is the legacy pre-CUDA 9 vote intrinsic; newer toolkits
  // deprecate it in favor of __ballot_sync.)
  T vote = __ballot(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);

  int warp = threadIdx.x / 32;

  // For each warp, lane 0 writes out that warp's count
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  __syncthreads();

  // Sum across warps in one thread. This appears to be faster than a
  // warp shuffle scan for CC 3.0+
  if (threadIdx.x == 0) {
    int current = 0;
    for (int i = 0; i < blockDim.x / 32; ++i) {
      T v = smem[i];
      smem[i] += current;
      current += v;
    }
  }

  __syncthreads();

  // Add the counts of all preceding warps
  if (warp >= 1) {
    index += smem[warp - 1];
  }

  *out = index;

  if (KillWARDependency) {
    __syncthreads();
  }
}
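
// Worked example (illustrative): in a warp where only lanes 0, 2, and 5 pass
// in == true, __ballot yields a mask with bits 0, 2, and 5 set. Lane 2 then
// computes index = __popc(getLaneMaskLe() & vote) = 2, lane 5 computes 3, and
// every lane sees carry = __popc(vote) = 3; the cross-warp pass above adds
// the counts of all preceding warps on top of that.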

// Exclusive prefix sum for binary vars using intra-warp voting +
// shared memory
template <typename T, bool KillWARDependency>
__device__ void exclusiveBinaryPrefixSum(T* smem, bool in, T* out, T* carry) {
  inclusiveBinaryPrefixSum<T, false>(smem, in, out);

  // Inclusive to exclusive
  *out -= (T) in;

  // The outgoing carry for all threads is the last warp's sum
  *carry = smem[(blockDim.x / 32) - 1];

  if (KillWARDependency) {
    __syncthreads();
  }
}
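
// Usage sketch (illustrative, not part of the original header): a
// hypothetical single-block stream-compaction kernel built on
// exclusiveBinaryPrefixSum. The exclusive scan of the keep flags gives each
// kept element its output slot, and the carry gives the number of kept
// elements. Assumes blockDim.x is a multiple of 32 and dynamic shared memory
// of (blockDim.x / 32) * sizeof(unsigned int).
template <typename T>
__global__ void compactNonzeroTileSketch(const T* in, T* out, int* numKept) {
  extern __shared__ unsigned int warpCountsSmem[];

  T val = in[threadIdx.x];
  bool keep = (val != T(0));

  unsigned int slot;
  unsigned int total;
  // All threads of the block participate, whether or not they keep a value
  exclusiveBinaryPrefixSum<unsigned int, false>(
      warpCountsSmem, keep, &slot, &total);

  if (keep) {
    out[slot] = val;
  }
  if (threadIdx.x == 0) {
    *numKept = (int) total;
  }
}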
#endif // THC_SCAN_UTILS_INC