Skip to content

Commit

Permalink
Avoid double-initializing partial accumulate results.
Browse files (browse the repository at this point in the history)
  • Loading branch information
maleadt committed Nov 9, 2021
1 parent 1043661 commit 8705711
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,21 @@ function aggregate_partial_scan(op::Function, output::AbstractArray,
# iterate the other dimensions using the remaining block dimensions
j = (blockIdx().z-1i32) * gridDim().y + blockIdx().y

@inbounds if block > 1 && i <= length(Rdim) && j <= length(Rother)
@inbounds if i <= length(Rdim) && j <= length(Rother)
I = Rother[j]
Ipre = Rpre[I[1]]
Ipost = Rpost[I[2]]

val = op(aggregates[Ipre, block-1, Ipost], output[Ipre, i, Ipost])
val = if block > 1
op(aggregates[Ipre, block-1, Ipost], output[Ipre, i, Ipost])
else
output[Ipre, i, Ipost]
end

if init !== nothing
val = op(something(init), val)
end

output[Ipre, i, Ipost] = val
end

Expand Down Expand Up @@ -163,10 +169,11 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(true)))
else
# perform partial scans across the scanning dimension
# NOTE: don't pass `init` to the partial scans here; it is applied once, later, when the
#       per-block aggregates are combined, so passing it here too would apply it twice
partial = prevpow(2, kernel_config.threads)
blocks_dim = cld(length(Rdim), partial)
@cuda(threads=partial, blocks=(blocks_dim, blocks_other...), shmem=2*partial*sizeof(T),
partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(true)))
partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, nothing, Val(true)))

# get the total of each thread block (except the first) of the partial scans
aggregates = fill(neutral, Base.setindex(size(input), blocks_dim, dims))
Expand Down
3 changes: 2 additions & 1 deletion test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ end
@testset "accumulate" begin
for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not
@test testf(x->accumulate(+, x), rand(n))
@test testf((x,y)->accumulate(+, x; init=y), rand(n), rand())
end

# multidimensional
Expand All @@ -307,7 +308,7 @@ end
for (sizes, dims) in ((2,) => 2,
(3,4,5) => 2,
(1, 70, 50, 20) => 3)
@test testf(x->accumulate(+, x; dims=dims, init=100.), rand(Int, sizes))
@test testf((x,y)->accumulate(+, x; dims=dims, init=y), rand(Int, sizes), rand(Int))
end

# in place
Expand Down

0 comments on commit 8705711

Please sign in to comment.