Rename delayed=true to launch=false and always return the kernel object.
maleadt committed Nov 25, 2020
1 parent 86450b5 commit 06c0a44
Showing 10 changed files with 38 additions and 39 deletions.
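
All of the diffs below follow the same occupancy-driven pattern: compile the kernel without launching it, query `launch_configuration` for a suitable block/grid size, then launch by calling the returned kernel object. A minimal sketch of that pattern under the new spelling (the `vadd_kernel` below is hypothetical and not part of this commit):

using CUDA

# Hypothetical kernel, used only to illustrate the pattern.
function vadd_kernel(c, a, b)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(c)
        @inbounds c[i] = a[i] + b[i]
    end
    return
end

a = CUDA.rand(Float32, 2^20)
b = CUDA.rand(Float32, 2^20)
c = similar(a)

# before this commit: kernel = @cuda delayed=true vadd_kernel(c, a, b)
kernel = @cuda launch=false vadd_kernel(c, a, b)   # compile, but do not launch
config = launch_configuration(kernel.fun)          # query the occupancy API
threads = min(length(c), config.threads)
blocks = cld(length(c), threads)
kernel(c, a, b; threads=threads, blocks=blocks)    # launch by calling the kernel object
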
2 changes: 1 addition & 1 deletion examples/pairwise.jl
@@ -96,7 +96,7 @@ function pairwise_dist_gpu(lat::Vector{Float32}, lon::Vector{Float32})
# calculate the amount of dynamic shared memory for a 2D block size
get_shmem(threads) = 2 * sum(threads) * sizeof(Float32)

kernel = @cuda delayed=true pairwise_dist_kernel(lat_gpu, lon_gpu, rowresult_gpu, n)
kernel = @cuda launch=false pairwise_dist_kernel(lat_gpu, lon_gpu, rowresult_gpu, n)
config = launch_configuration(kernel.fun, shmem=threads->get_shmem(get_threads(threads)))

# convert to 2D block size and figure out appropriate grid size
2 changes: 1 addition & 1 deletion examples/peakflops.jl
@@ -37,7 +37,7 @@ function peakflops(n::Integer=5000, dev::CuDevice=CuDevice(0))

len = prod(dims)

kernel = @cuda delayed=true kernel_100fma(d_a, d_b, d_c, d_out)
kernel = @cuda launch=false kernel_100fma(d_a, d_b, d_c, d_out)
config = launch_configuration(kernel.fun)
threads = Base.min(len, config.threads)
blocks = cld(len, threads)
2 changes: 1 addition & 1 deletion perf/kernel.jl
@@ -5,7 +5,7 @@ group["launch"] = @benchmarkable @cuda dummy_kernel()

wanted_threads = 10000
group["occupancy"] = @benchmarkable begin
kernel = @cuda delayed=true dummy_kernel()
kernel = @cuda launch=false dummy_kernel()
config = launch_configuration(kernel.fun)
threads = Base.min($wanted_threads, config.threads)
blocks = cld($wanted_threads, threads)
2 changes: 1 addition & 1 deletion src/accumulate.jl
@@ -145,7 +145,7 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
Rother = CartesianIndices((length(Rpre), length(Rpost)))

# determine how many threads we can launch for the scan kernel
kernel = @cuda delayed=true partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(true))
kernel = @cuda launch=false partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(true))
kernel_config = launch_configuration(kernel.fun; shmem=(threads)->2*threads*sizeof(T))

# determine the grid layout to cover the other dimensions
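
The scan kernel above sizes its dynamic shared memory through a closure passed to `launch_configuration`, since the requirement depends on the chosen block size. A standalone sketch of that variant, using a hypothetical single-block kernel that is not part of this commit:

using CUDA

# Hypothetical kernel: sums the first blockDim().x elements of xs into out[1],
# staging them in dynamic shared memory. Illustration only.
function block_sum_kernel(out, xs)
    tmp = @cuDynamicSharedMem(Float32, blockDim().x)
    i = threadIdx().x
    @inbounds tmp[i] = i <= length(xs) ? xs[i] : 0f0
    sync_threads()
    if i == 1
        s = 0f0
        for j in 1:blockDim().x
            @inbounds s += tmp[j]
        end
        @inbounds out[1] = s
    end
    return
end

xs = CUDA.rand(Float32, 512)
out = CUDA.zeros(Float32, 1)

kernel = @cuda launch=false block_sum_kernel(out, xs)
# the shared-memory requirement scales with the block size, so pass it as a function
config = launch_configuration(kernel.fun; shmem=threads->threads*sizeof(Float32))
threads = min(length(xs), config.threads)
kernel(out, xs; threads=threads, shmem=threads*sizeof(Float32))
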
31 changes: 15 additions & 16 deletions src/compiler/execution.jl
@@ -15,8 +15,9 @@ managed automatically using `cudaconvert`. Finally, a call to `cudacall` is
performed, scheduling a kernel launch on the current CUDA context.
Several keyword arguments are supported that influence the behavior of `@cuda`.
- `dynamic`: use dynamic parallelism to launch device-side kernels
- `delayed`: return a callable kernel object instead of launching the kernel directly
- `launch`: whether to launch this kernel, defaults to `true`. If `false` the returned
kernel object should be launched by calling it and passing arguments again.
- `dynamic`: use dynamic parallelism to launch device-side kernels, defaults to `false`.
- arguments that influence kernel compilation: see [`cufunction`](@ref) and
[`dynamic_cufunction`](@ref)
- arguments that influence kernel launch: see [`CUDA.HostKernel`](@ref) and
@@ -38,7 +39,7 @@ macro cuda(ex...)
# group keyword argument
macro_kwargs, compiler_kwargs, call_kwargs, other_kwargs =
split_kwargs(kwargs,
[:dynamic, :delayed],
[:dynamic, :launch],
[:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name],
[:cooperative, :blocks, :threads, :config, :shmem, :stream])
if !isempty(other_kwargs)
@@ -48,21 +49,21 @@ macro cuda(ex...)

# handle keyword arguments that influence the macro's behavior
dynamic = false
delayed = false
launch = true
for kwarg in macro_kwargs
key,val = kwarg.args
if key == :dynamic
isa(val, Bool) || throw(ArgumentError("`dynamic` keyword argument to @cuda should be a constant value"))
dynamic = val::Bool
elseif key == :delayed
isa(val, Bool) || throw(ArgumentError("`delayed` keyword argument to @cuda should be a constant value"))
delayed = val::Bool
elseif key == :launch
isa(val, Bool) || throw(ArgumentError("`launch` keyword argument to @cuda should be a constant value"))
launch = val::Bool
else
throw(ArgumentError("Unsupported keyword argument '$key'"))
end
end
if delayed && !isempty(call_kwargs)
error("delayed @cuda does not support these call-time keyword arguments; use them when calling the kernel")
if !launch && !isempty(call_kwargs)
error("@cuda with launch=false does not support launch-time keyword arguments; use them when calling the kernel")
end

# FIXME: macro hygiene wrt. escaping kwarg values (this broke with 1.5)
@@ -81,11 +82,10 @@ macro cuda(ex...)
local $kernel_args = ($(var_exprs...),)
local $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
local $kernel = $dynamic_cufunction($f, $kernel_tt)
if $delayed
$kernel
else
if $launch
$kernel($kernel_args...; $(call_kwargs...))
end
$kernel
end)
else
# regular, host-side kernel launch
@@ -99,11 +99,10 @@
local $kernel_args = map($cudaconvert, ($(var_exprs...),))
local $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
local $kernel = $cufunction($f, $kernel_tt; $(compiler_kwargs...))
if $delayed
$kernel
else
if $launch
$kernel($(var_exprs...); $(call_kwargs...))
end
$kernel
end
end)
end
@@ -207,7 +206,7 @@ end

@inline function cudacall(kernel::HostKernel, tt, args...; config=nothing, kwargs...)
if config !== nothing
Base.depwarn("cudacall with config argument is deprecated, use `@cuda delayed=true` instead", :cudacall)
Base.depwarn("cudacall with config argument is deprecated, use `@cuda launch=false` and instrospect the returned kernel instead", :cudacall)
cudacall(kernel.fun, tt, args...; kwargs..., config(kernel)...)
else
cudacall(kernel.fun, tt, args...; kwargs...)
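
Besides the rename, the macro change above makes `@cuda` return the compiled kernel object even when it launches immediately, so the introspection helpers exercised in the test suite further down (`CUDA.registers`, `CUDA.maxthreads`, ...) also work after a regular invocation. A small sketch, again with a hypothetical kernel:

using CUDA

# Hypothetical kernel, for illustration only.
function increment_kernel(xs)
    i = threadIdx().x
    @inbounds xs[i] += 1f0
    return
end

xs = CUDA.zeros(Float32, 32)

k = @cuda threads=32 increment_kernel(xs)   # launches, and now also returns the kernel object
CUDA.registers(k)                           # introspect the compiled kernel
CUDA.maxthreads(k)
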
2 changes: 1 addition & 1 deletion src/gpuarrays.jl
@@ -14,7 +14,7 @@ struct CuKernelContext <: AbstractKernelContext end

@inline function GPUArrays.launch_heuristic(::CuArrayBackend, f::F, args::Vararg{Any,N};
maximize_blocksize=false) where {F,N}
kernel = @cuda delayed=true f(CuKernelContext(), args...)
kernel = @cuda launch=false f(CuKernelContext(), args...)
if maximize_blocksize
# some kernels benefit (algorithmically) from a large block size
launch_configuration(kernel.fun)
8 changes: 4 additions & 4 deletions src/indexing.jl
@@ -29,7 +29,7 @@ function Base.getindex(xs::AnyCuArray{T}, bools::AnyCuArray{Bool}) where {T}
return
end

kernel = @cuda name="logical_getindex" delayed=true kernel(ys, xs, bools, indices)
kernel = @cuda name="logical_getindex" launch=false kernel(ys, xs, bools, indices)
config = launch_configuration(kernel.fun)
threads = Base.min(length(indices), config.threads)
blocks = cld(length(indices), threads)
@@ -64,7 +64,7 @@ function Base.findall(bools::AnyCuArray{Bool})
return
end

kernel = @cuda name="findall" delayed=true kernel(ys, bools, indices)
kernel = @cuda name="findall" launch=false kernel(ys, bools, indices)
config = launch_configuration(kernel.fun)
threads = Base.min(length(indices), config.threads)
blocks = cld(length(indices), threads)
@@ -98,7 +98,7 @@ function Base.findfirst(testf::Function, xs::AnyCuArray)
return
end

kernel = @cuda name="findfirst" delayed=true kernel(y, xs)
kernel = @cuda name="findfirst" launch=false kernel(y, xs)
config = launch_configuration(kernel.fun)
threads = Base.min(length(xs), config.threads)
blocks = cld(length(xs), threads)
@@ -165,7 +165,7 @@ function findfirstval(vals::AnyCuArray, xs::AnyCuArray)
return
end

kernel = @cuda delayed=true kernel(xs, vals, indices)
kernel = @cuda launch=false kernel(xs, vals, indices)
config = launch_configuration(kernel.fun)
threads = Base.min(length(xs), config.threads)
blocks = cld(length(xs), threads)
2 changes: 1 addition & 1 deletion src/mapreduce.jl
@@ -189,7 +189,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::AnyCuArray{T},
# we might not be able to launch all those threads to reduce each slice in one go.
# that's why each threads also loops across their inputs, processing multiple values
# so that we can span the entire reduction dimension using a single thread block.
kernel = @cuda delayed=true partial_mapreduce_grid(f, op, init, Rreduce, Rother, Val(shuffle), R′, A)
kernel = @cuda launch=false partial_mapreduce_grid(f, op, init, Rreduce, Rother, Val(shuffle), R′, A)
compute_shmem(threads) = shuffle ? 0 : 2*threads*sizeof(T)
kernel_config = launch_configuration(kernel.fun; shmem=compute_shmem∘compute_threads)
reduce_threads = compute_threads(kernel_config.threads)
24 changes: 12 additions & 12 deletions test/execution.jl
@@ -8,18 +8,6 @@ dummy() = return
@test_throws MethodError @cuda dummy(1)


@testset "delayed kernel" begin
k = @cuda delayed=true dummy()
k()
k(; threads=1)

CUDA.version(k)
CUDA.memory(k)
CUDA.registers(k)
CUDA.maxthreads(k)
end


@testset "launch configuration" begin
@cuda dummy()

@@ -33,6 +21,18 @@
end


@testset "launch=false" begin
k = @cuda launch=false dummy()
k()
k(; threads=1)

CUDA.version(k)
CUDA.memory(k)
CUDA.registers(k)
CUDA.maxthreads(k)
end


@testset "compilation params" begin
@cuda dummy()

2 changes: 1 addition & 1 deletion test/texture.jl
@@ -27,7 +27,7 @@ function fetch_all(texture)
dims = size(texture)
d_out = CuArray{eltype(texture)}(undef, dims...)

kernel = @cuda delayed=true kernel_texture_warp_native(d_out, texture)
kernel = @cuda launch=false kernel_texture_warp_native(d_out, texture)
config = launch_configuration(kernel.fun)

dim_x, dim_y, dim_z = size(texture, 1), size(texture, 2), size(texture, 3)
