forked from JuliaGPU/CUDA.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaccumulate.jl
213 lines (173 loc) · 7.18 KB
/
accumulate.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# scan and accumulate
## COV_EXCL_START
# partial scan of individual thread blocks within a grid
# work-efficient implementation after Blelloch (1990)
#
# number of threads needs to be a power-of-2
#
# performance TODOs:
# - shuffle
# - warp-aggregate atomics
# - the ND case is quite a bit slower than the 1D case (not using Cartesian indices,
# before 35fcbde1f2987023229034370b0c9091e18c4137). optimize or special-case?
# Per-block partial scan kernel: each thread block scans its own chunk of the
# main dimension in shared memory using the work-efficient tree algorithm of
# Blelloch (1990). When the input spans multiple blocks, `output` only holds
# block-local results afterwards; `aggregate_partial_scan` folds in the totals
# of preceding blocks in a second pass.
#
# Arguments:
# - `op`: binary (associative) operator being scanned
# - `Rdim`: Cartesian indices over the scanned dimension
# - `Rpre`/`Rpost`: Cartesian indices over the dimensions before/after it
# - `Rother`: linearized iteration domain covering `Rpre` × `Rpost`
# - `neutral`: neutral element of `op`, used to pad out-of-range lanes
# - `init`: optional wrapped initial value folded into every output element,
#   or `nothing`
# - `Val(inclusive)`: select an inclusive (default) or exclusive scan
#
# NOTE: the tree traversal assumes blockDim().x is a power of two.
function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
                      Rdim, Rpre, Rpost, Rother, neutral, init,
                      ::Val{inclusive}=Val(true)) where {T, inclusive}
    threads = blockDim().x
    thread = threadIdx().x
    block = blockIdx().x

    # dynamic shared-memory scratch buffer, sized 2*threads elements of T
    # (the launch site passes a matching `shmem` amount)
    temp = CuDynamicSharedArray(T, (2*threads,))

    # iterate the main dimension using threads and the first block dimension
    i = (blockIdx().x-1i32) * blockDim().x + threadIdx().x

    # iterate the other dimensions using the remaining block dimensions
    j = (blockIdx().z-1i32) * gridDim().y + blockIdx().y
    if j > length(Rother)
        return
    end

    @inbounds begin
        I = Rother[j]
        Ipre = Rpre[I[1]]
        Ipost = Rpost[I[2]]
    end

    # load input into shared memory (apply `op` to have the correct type)
    @inbounds temp[thread] = if i <= length(Rdim)
        op(neutral, input[Ipre, i, Ipost])
    else
        op(neutral, neutral)   # out-of-range lanes contribute the neutral element
    end

    # up-sweep: build sum in place up the tree
    offset = 1
    d = threads>>1
    while d > 0
        sync_threads()
        @inbounds if thread <= d
            ai = offset * (2*thread-1)
            bi = offset * (2*thread)
            temp[bi] = op(temp[ai], temp[bi])
        end
        offset *= 2
        d >>= 1
    end

    # clear the last element (root of the tree) so the down-sweep produces an
    # exclusive scan
    @inbounds if thread == 1
        temp[threads] = neutral
    end

    # down-sweep: traverse down tree & build scan
    d = 1
    while d < threads
        offset >>= 1
        sync_threads()
        @inbounds if thread <= d
            ai = offset * (2*thread-1)
            bi = offset * (2*thread)

            t = temp[ai]
            temp[ai] = temp[bi]
            temp[bi] = op(t, temp[bi])
        end
        d *= 2
    end
    sync_threads()

    # write results to device memory: fold the current element back in for an
    # inclusive scan, and apply the user-provided initial value if any
    @inbounds if i <= length(Rdim)
        val = if inclusive
            op(temp[thread], input[Ipre, i, Ipost])
        else
            temp[thread]
        end
        if init !== nothing
            val = op(something(init), val)
        end
        output[Ipre, i, Ipost] = val
    end

    return
end
# Second pass of the multi-block scan: combine each element of the per-block
# partial scan with the running total of all preceding blocks (`aggregates`),
# then fold in the optional initial value.
function aggregate_partial_scan(op::Function, output::AbstractArray,
                                aggregates::AbstractArray, Rdim, Rpre, Rpost, Rother,
                                init)
    blk = blockIdx().x

    # position along the scanned dimension (threads × first grid dimension)
    elem = (blockIdx().x-1i32) * blockDim().x + threadIdx().x

    # linear position within the remaining dimensions (second/third grid dims)
    other = (blockIdx().z-1i32) * gridDim().y + blockIdx().y

    if elem <= length(Rdim) && other <= length(Rother)
        @inbounds begin
            cart = Rother[other]
            pre = Rpre[cart[1]]
            post = Rpost[cart[2]]

            val = output[pre, elem, post]
            # blocks after the first add the total of everything before them
            if blk > 1
                val = op(aggregates[pre, blk-1, post], val)
            end
            if init !== nothing
                val = op(something(init), val)
            end
            output[pre, elem, post] = val
        end
    end

    return
end
## COV_EXCL_STOP
# Scan (cumulative reduction) of `input` with operator `f` along dimension
# `dims`, writing the result into `output` (which must have matching axes).
# `init`, if given, is folded into every element; `neutral` is the padding
# value for partially-filled thread blocks.
#
# Strategy: if the scanned dimension fits in a single thread block, launch one
# `partial_scan` covering it. Otherwise, launch per-block partial scans, scan
# the block totals recursively (via `accumulate!`), and launch
# `aggregate_partial_scan` to fold those totals back in.
function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
               dims::Integer, init=nothing, neutral=GPUArrays.neutral_element(f, T)) where {T}
    dims > 0 || throw(ArgumentError("dims must be a positive integer"))
    inds_t = axes(input)
    axes(output) == inds_t || throw(DimensionMismatch("shape of B must match A"))
    # scanning along a nonexistent trailing dimension is the identity
    dims > ndims(input) && return copyto!(output, input)
    isempty(inds_t[dims]) && return output

    # iteration domain across the main dimension
    Rdim = CartesianIndices((size(input, dims),))

    # iteration domain for the other dimensions
    Rpre = CartesianIndices(size(input)[1:dims-1])
    Rpost = CartesianIndices(size(input)[dims+1:end])
    Rother = CartesianIndices((length(Rpre), length(Rpost)))

    # determine how many threads we can launch for the scan kernel
    # (occupancy-limited, accounting for the 2*threads shared-memory scratch)
    kernel = @cuda launch=false partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(true))
    kernel_config = launch_configuration(kernel.fun; shmem=(threads)->2*threads*sizeof(T))

    # determine the grid layout to cover the other dimensions
    # (split across grid dims y and z since y alone may be too small)
    if length(Rother) > 1
        dev = device()
        max_other_blocks = attribute(dev, DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y)
        blocks_other = (min(length(Rother), max_other_blocks),
                        cld(length(Rother), max_other_blocks))
    else
        blocks_other = (1, 1)
    end

    # does that suffice to scan the array in one go?
    full = nextpow(2, length(Rdim))   # kernel requires a power-of-2 thread count
    if full <= kernel_config.threads
        @cuda(threads=full, blocks=(1, blocks_other...), shmem=2*full*sizeof(T), name="scan",
              partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(true)))
    else
        # perform partial scans across the scanning dimension
        # NOTE: don't set init here to avoid applying the value multiple times
        partial = prevpow(2, kernel_config.threads)
        blocks_dim = cld(length(Rdim), partial)
        @cuda(threads=partial, blocks=(blocks_dim, blocks_other...), shmem=2*partial*sizeof(T),
              partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, nothing, Val(true)))

        # get the total of each thread block (except the first) of the partial scans
        aggregates = fill(neutral, Base.setindex(size(input), blocks_dim, dims))
        copyto!(aggregates, selectdim(output, dims, partial:partial:length(Rdim)))

        # scan these totals to get totals for the entire partial scan
        accumulate!(f, aggregates, aggregates; dims=dims)

        # add those totals to the partial scan result
        # NOTE: we assume that this kernel requires fewer resources than the scan kernel.
        #       if that does not hold, launch with fewer threads and calculate
        #       the aggregate block index within the kernel itself.
        @cuda(threads=partial, blocks=(blocks_dim, blocks_other...),
              aggregate_partial_scan(f, output, aggregates, Rdim, Rpre, Rpost, Rother, init))

        unsafe_free!(aggregates)
    end

    return output
end
## Base interface

Base._accumulate!(op, output::AnyCuArray, input::AnyCuVector, dims::Nothing, init::Nothing) =
    scan!(op, output, input; dims=1)

Base._accumulate!(op, output::AnyCuArray, input::AnyCuArray, dims::Integer, init::Nothing) =
    scan!(op, output, input; dims=dims)

# accept AnyCuVector (not just CuVector) so wrapped device arrays dispatch here
# too, consistent with the `init::Nothing` method above
Base._accumulate!(op, output::AnyCuArray, input::AnyCuVector, dims::Nothing, init::Some) =
    scan!(op, output, input; dims=1, init=init)

Base._accumulate!(op, output::AnyCuArray, input::AnyCuArray, dims::Integer, init::Some) =
    scan!(op, output, input; dims=dims, init=init)

# Base's pairwise accumulation is a CPU-oriented optimization; route vector
# accumulation through the GPU scan instead.
Base.accumulate_pairwise!(op, result::AnyCuVector, v::AnyCuVector) = accumulate!(op, result, v)