Skip to content

Commit

Permalink
Use contextual dispatch for device functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 17, 2021
1 parent 42f5562 commit 6f322bd
Show file tree
Hide file tree
Showing 16 changed files with 283 additions and 493 deletions.
29 changes: 27 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.29"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b"
Expand Down Expand Up @@ -77,14 +83,21 @@ version = "6.2.0"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "ef2839b063e158672583b9c09d2cf4876a8d3d55"
git-tree-sha1 = "b6c3b8e2df6ffe0da0b10e2045ce35a3cf618b8a"
repo-rev = "1ecbe42"
repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.10.0"
version = "0.10.1"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JLLWrappers]]
git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.2.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
Expand Down Expand Up @@ -150,6 +163,12 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand Down Expand Up @@ -205,6 +224,12 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

Expand Down
13 changes: 13 additions & 0 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ using BFloat16s

using Memoize

using ExprTools


##

const ci_cache = GPUCompiler.CodeCache()

@static if VERSION >= v"1.7-"
Base.Experimental.@MethodTable(method_table)
else
const method_table = nothing
end


## source code includes

Expand Down
2 changes: 0 additions & 2 deletions src/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
dims > ndims(input) && return copyto!(output, input)
isempty(inds_t[dims]) && return output

f = cufunc(f)

# iteration domain across the main dimension
Rdim = CartesianIndices((size(input, dims),))

Expand Down
99 changes: 3 additions & 96 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,99 +14,6 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =
Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims) where {N,T} =
CuArray{T}(undef, dims)


## replace base functions with libdevice alternatives

cufunc(f) = f
cufunc(::Type{T}) where T = (x...) -> T(x...) # broadcasting type ctors isn't GPU compatible

Broadcast.broadcasted(::CuArrayStyle{N}, f, args...) where {N} =
Broadcasted{CuArrayStyle{N}}(cufunc(f), args, nothing)

const device_intrinsics = :[
cos, cospi, sin, sinpi, tan, acos, asin, atan,
cosh, sinh, tanh, acosh, asinh, atanh, angle,
log, log10, log1p, log2, logb, ilogb,
exp, exp2, exp10, expm1, ldexp,
erf, erfinv, erfc, erfcinv, erfcx,
brev, clz, ffs, byte_perm, popc,
isfinite, isinf, isnan, nearbyint,
nextafter, signbit, copysign, abs,
sqrt, rsqrt, cbrt, rcbrt, pow,
ceil, floor, saturate,
lgamma, tgamma,
j0, j1, jn, y0, y1, yn,
normcdf, normcdfinv, hypot,
fma, sad, dim, mul24, mul64hi, hadd, rhadd, scalbn].args

for f in device_intrinsics
isdefined(Base, f) || continue
@eval cufunc(::typeof(Base.$f)) = $f
end

# broadcast ^

culiteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x)
culiteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x
culiteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x
culiteral_pow(::typeof(^), x::T, ::Val{3}) where {T<:Real} = x * x * x
culiteral_pow(::typeof(^), x::T, ::Val{p}) where {T<:Real,p} = pow(x, Int32(p))

cufunc(::typeof(Base.literal_pow)) = culiteral_pow
cufunc(::typeof(Base.:(^))) = pow

using MacroTools

const _cufuncs = [copy(device_intrinsics); :^]
cufuncs() = (global _cufuncs; _cufuncs)

_cuint(x::Int) = Int32(x)
_cuint(x::Expr) = x.head == :call && x.args[1] == :Int32 && x.args[2] isa Int ? Int32(x.args[2]) : x
_cuint(x) = x

function _cupowliteral(x::Expr)
if x.head == :call && x.args[1] == :(CUDA.cufunc(^)) && x.args[3] isa Int32
num = x.args[3]
if 0 <= num <= 3
sym = gensym(:x)
new_x = Expr(:block, :($sym = $(x.args[2])))

if iszero(num)
push!(new_x.args, :(one($sym)))
else
unroll = Expr(:call, :*)
for x = one(num):num
push!(unroll.args, sym)
end
push!(new_x.args, unroll)
end

x = new_x
end
end
x
end
_cupowliteral(x) = x

function replace_device(ex)
global _cufuncs
MacroTools.postwalk(ex) do x
x = x in _cufuncs ? :(CUDA.cufunc($x)) : x
x = _cuint(x)
x = _cupowliteral(x)
x
end
end

macro cufunc(ex)
global _cufuncs
def = MacroTools.splitdef(ex)
f = def[:name]
def[:name] = Symbol(:cu, f)
def[:body] = replace_device(def[:body])
push!(_cufuncs, f)
quote
$(esc(MacroTools.combinedef(def)))
CUDA.cufunc(::typeof($(esc(f)))) = $(esc(def[:name]))
end
end
# broadcasting type ctors isn't GPU compatible
Broadcast.broadcasted(::CuArrayStyle{N}, f::Type{T}, args...) where {N, T} =
Broadcasted{CuArrayStyle{N}}((x...) -> T(x...), args, nothing)
4 changes: 4 additions & 0 deletions src/compiler/gpucompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ function GPUCompiler.link_libraries!(job::CUDACompilerJob, mod::LLVM.Module,
job, mod, undefined_fns)
link_libdevice!(mod, job.target.cap, undefined_fns)
end

GPUCompiler.ci_cache(::CUDACompilerJob) = ci_cache

GPUCompiler.method_table(::CUDACompilerJob) = method_table
29 changes: 29 additions & 0 deletions src/device/intrinsics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# wrappers for functionality provided by the CUDA toolkit

const overrides = quote end

macro device_override(ex)
code = quote
$GPUCompiler.@override($method_table, $ex)
end
if VERSION >= v"1.7-"
return esc(code)
else
push!(overrides.args, code)
return
end
end

macro device_function(ex)
ex = macroexpand(__module__, ex)
def = splitdef(ex)

# generate a function that errors
def[:body] = quote
error("This function is not intended for use on the CPU")
end

esc(quote
$(combinedef(def))
@device_override $ex
end)
end

# extensions to the C language
include("intrinsics/memory_shared.jl")
include("intrinsics/indexing.jl")
Expand Down
Loading

0 comments on commit 6f322bd

Please sign in to comment.