Skip to content

Commit

Permalink
Rework device and context getters.
Browse files Browse the repository at this point in the history
Use lower-case device and context, preserve CuContext
and CuDevice for constructors.
  • Loading branch information
maleadt committed Sep 30, 2021
1 parent 5b74388 commit 0d4878a
Show file tree
Hide file tree
Showing 17 changed files with 103 additions and 85 deletions.
42 changes: 10 additions & 32 deletions lib/cudadrv/context.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Context management

export
CuPrimaryContext, CuContext, CuCurrentContext, activate,
CuPrimaryContext, CuContext, current_context, activate,
unsafe_reset!, isactive, flags, setflags!,
device_synchronize, CuCurrentDevice,
device, device_synchronize


## construction and destruction
Expand Down Expand Up @@ -67,11 +67,11 @@ mutable struct CuContext
end

"""
CuCurrentContext()
current_context()
Return the current context, or `nothing` if there is no active context.
"""
global function CuCurrentContext()
global function current_context()
handle_ref = Ref{CUcontext}()
cuCtxGetCurrent(handle_ref)
if handle_ref[] == C_NULL
Expand All @@ -81,14 +81,8 @@ mutable struct CuContext
end
end

"""
CuContext(ptr)
Identify the context a CUDA memory buffer was allocated in.
"""
function CuContext(x::Union{Ptr,CuPtr})
new_unique(attribute(CUcontext, x, POINTER_ATTRIBUTE_CONTEXT))
end
# for outer constructors
global _CuContext(handle::CUcontext) = new_unique(handle)
end

# the `valid` bit serves two purposes: make sure we don't double-free a context (in case we
Expand Down Expand Up @@ -206,7 +200,7 @@ does not respect any users of the context, and might make other objects unusable
"""
function unsafe_release!(ctx::CuContext)
if isvalid(ctx)
dev = CuDevice(ctx)
dev = device(ctx)
pctx = CuPrimaryContext(dev)
if version() >= v"11"
cuDevicePrimaryCtxRelease_v2(dev)
Expand Down Expand Up @@ -283,29 +277,13 @@ end
## context properties

"""
CuCurrentDevice()
Returns the current device, or `nothing` if there is no active device.
"""
function CuCurrentDevice()
device_ref = Ref{CUdevice}()
res = unsafe_cuCtxGetDevice(device_ref)
if res == ERROR_INVALID_CONTEXT
return nothing
elseif res != SUCCESS
throw_api_error(res)
end
return CuDevice(Bool, device_ref[])
end

"""
CuDevice(::CuContext)
device(::CuContext)
Returns the device for a context.
"""
function CuDevice(ctx::CuContext)
function device(ctx::CuContext)
push!(CuContext, ctx)
dev = CuCurrentDevice()
dev = current_device()
pop!(CuContext)
return dev
end
Expand Down
48 changes: 31 additions & 17 deletions lib/cudadrv/devices.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
# Device type and auxiliary functions

export
CuDevice, name, uuid, totalmem, attribute
CuDevice, current_device, name, uuid, totalmem, attribute


"""
CuDevice(i::Integer)
Get a handle to a compute device.
"""
struct CuDevice
handle::CUdevice

# CuDevice is just an integer, but we need (?) to call cuDeviceGet to make sure this
# integer is valid. to avoid ambiguity, add a bogus argument (cfr. `checkbounds`)
CuDevice(::Type{Bool}, handle::CUdevice) = new(handle)
end
"""
CuDevice(ordinal::Integer)
const DEVICE_CPU = CuDevice(Bool, CUdevice(-1))
const DEVICE_INVALID = CuDevice(Bool, CUdevice(-2))
Get a handle to a compute device.
"""
function CuDevice(ordinal::Integer)
device_ref = Ref{CUdevice}()
cuDeviceGet(device_ref, ordinal)
new(device_ref[])
end

function CuDevice(ordinal::Integer)
device_ref = Ref{CUdevice}()
cuDeviceGet(device_ref, ordinal)
CuDevice(Bool, device_ref[])
"""
current_device()
Returns the current device, or `nothing` if there is no active device.
"""
global function current_device()
device_ref = Ref{CUdevice}()
res = unsafe_cuCtxGetDevice(device_ref)
if res == ERROR_INVALID_CONTEXT
return nothing
elseif res != SUCCESS
throw_api_error(res)
end
return _CuDevice(device_ref[])
end

# for outer constructors
global _CuDevice(handle::CUdevice) = new(handle)
end

const DEVICE_CPU = _CuDevice(CUdevice(-1))
const DEVICE_INVALID = _CuDevice(CUdevice(-2))

Base.convert(::Type{CUdevice}, dev::CuDevice) = dev.handle

Base.:(==)(a::CuDevice, b::CuDevice) = a.handle == b.handle
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mutable struct CuEvent
handle_ref = Ref{CUevent}()
cuEventCreate(handle_ref, flags)

ctx = CuCurrentContext()
ctx = current_context()
obj = new(handle_ref[], ctx)
finalizer(unsafe_destroy!, obj)
return obj
Expand Down
6 changes: 3 additions & 3 deletions lib/cudadrv/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mutable struct CuGraph
handle_ref = Ref{CUgraph}()
cuGraphCreate(handle_ref, flags)

ctx = CuCurrentContext()
ctx = current_context()
obj = new(handle_ref[], ctx)
finalizer(unsafe_destroy!, obj)
return obj
Expand All @@ -39,7 +39,7 @@ mutable struct CuGraph
global function capture(f::Function; flags=STREAM_CAPTURE_MODE_GLOBAL, throw_error::Bool=true)
cuStreamBeginCapture_v2(stream(), flags)

ctx = CuCurrentContext()
ctx = current_context()
obj = nothing
try
f()
Expand Down Expand Up @@ -93,7 +93,7 @@ mutable struct CuGraphExec
# TODO: how to use these?
end

ctx = CuCurrentContext()
ctx = current_context()
obj = new(handle_ref[], graph, ctx)
finalizer(unsafe_destroy!, obj)
return obj
Expand Down
18 changes: 16 additions & 2 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,13 +738,27 @@ end

# some common attributes

"""
context(ptr)
Identify the context a CUDA memory buffer was allocated in.
"""
context(ptr::Union{Ptr,CuPtr}) =
_CuContext(attribute(CUcontext, ptr, POINTER_ATTRIBUTE_CONTEXT))

"""
device(ptr)
Identify the device a CUDA memory buffer was allocated on.
"""
device(x::Union{Ptr,CuPtr}) =
CuDevice(convert(Int, attribute(Cuint, x, POINTER_ATTRIBUTE_DEVICE_ORDINAL)))

@enum_without_prefix CUmemorytype CU_
memory_type(x) = CUmemorytype(attribute(Cuint, x, POINTER_ATTRIBUTE_MEMORY_TYPE))

is_managed(x) = convert(Bool, attribute(Cuint, x, POINTER_ATTRIBUTE_IS_MANAGED))

CuDevice(x::Union{Ptr,CuPtr}) = CuDevice(convert(Int, attribute(Cuint, x, POINTER_ATTRIBUTE_DEVICE_ORDINAL)))

function is_pinned(ptr::Ptr)
# unpinned memory makes cuPointerGetAttribute return ERROR_INVALID_VALUE; but instead of
# calling `memory_type` with an expensive try/catch we perform low-level API calls.
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ mutable struct CuModule
end
end

ctx = CuCurrentContext()
ctx = current_context()
obj = new(handle_ref[], ctx)
finalizer(unsafe_unload!, obj)
return obj
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/module/linker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ mutable struct CuLink

cuLinkCreate_v2(length(optionKeys), optionKeys, optionVals, handle_ref)

ctx = CuCurrentContext()
ctx = current_context()
obj = new(handle_ref[], ctx, options, optionKeys, optionVals)
finalizer(unsafe_destroy!, obj)
return obj
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/occupancy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function occupancy(fun::CuFunction, threads::Integer; shmem::Integer=0)

mod = fun.mod
ctx = mod.ctx
dev = CuDevice(ctx)
dev = device(ctx)

threads_per_sm = attribute(dev, DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR)
warp_size = attribute(dev, DEVICE_ATTRIBUTE_WARP_SIZE)
Expand Down
6 changes: 3 additions & 3 deletions lib/cudadrv/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mutable struct CuMemoryPool
handle_ref = Ref{CUmemoryPool}()
cuMemPoolCreate(handle_ref, props)

ctx = CuCurrentContext()
ctx = current_context()
obj = new(handle_ref[], ctx)
finalizer(unsafe_destroy!, obj)
return obj
Expand All @@ -35,15 +35,15 @@ mutable struct CuMemoryPool
handle_ref = Ref{CUmemoryPool}()
cuDeviceGetDefaultMemPool(handle_ref, dev)

ctx = CuCurrentContext()
ctx = current_context()
new(handle_ref[], ctx)
end

global function memory_pool(dev::CuDevice)
handle_ref = Ref{CUmemoryPool}()
cuDeviceGetMemPool(handle_ref, dev)

ctx = CuCurrentContext()::CuContext
ctx = current_context()::CuContext
new(handle_ref[], ctx)
end
end
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mutable struct CuStream
cuStreamCreateWithPriority(handle_ref, flags, priority)
end

ctx = CuCurrentContext()::CuContext
ctx = current_context()::CuContext
obj = new(handle_ref[], ctx)
finalizer(unsafe_destroy!, obj)
return obj
Expand Down
5 changes: 5 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,11 @@ Base.elsize(::Type{<:CuArray{T}}) where {T} = sizeof(T)
Base.size(x::CuArray) = x.dims
Base.sizeof(x::CuArray) = Base.elsize(x) * length(x)

function device(A::CuArray)
A.storage === nothing && throw(UndefRefError())
return device(A.storage.ctx)
end


## derived types

Expand Down
2 changes: 1 addition & 1 deletion src/compiler/exceptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function check_exceptions()
flag = unsafe_load(ptr)
if flag != 0
unsafe_store!(ptr, 0)
dev = CuDevice(ctx)
dev = device(ctx)
throw(KernelException(dev))
end
end
Expand Down
6 changes: 6 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
# Deprecated functionality

@deprecate CuDevice(ctx::CuContext) device(ctx)
@deprecate CuCurrentDevice() current_device()
@deprecate CuCurrentContext() current_context()
@deprecate CuContext(ptr::Union{Ptr,CuPtr}) context(ptr)
@deprecate CuDevice(ptr::Union{Ptr,CuPtr}) device(ptr)
10 changes: 5 additions & 5 deletions src/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ end
@inline function prepare_cuda_state()
state = task_local_state!()

# NOTE: CuCurrentContext() is too slow to use here (taking a lock, accessing a dict)
# NOTE: current_context() is too slow to use here (taking a lock, accessing a dict)
# so we use the raw handle. is that safe though, when we reset the device?
#ctx = CuCurrentContext()
#ctx = current_context()
ctx = Ref{CUcontext}()
cuCtxGetCurrent(ctx)
if ctx[] != state.context.handle
Expand All @@ -115,7 +115,7 @@ end
context()::CuContext
Get or create a CUDA context for the current thread (as opposed to
`CuCurrentContext` which may return `nothing` if there is no context bound to the
`current_context` which may return `nothing` if there is no context bound to the
current thread).
"""
function context()
Expand All @@ -133,7 +133,7 @@ Note that the contexts used with this call should be previously acquired by call
function context!(ctx::CuContext)
activate(ctx) # we generally only apply CUDA state lazily, i.e. in `prepare_cuda_state`,
# but we need to do so early here to be able to get the context's device.
dev = CuCurrentDevice()::CuDevice
dev = current_device()::CuDevice

# switch contexts
state = task_local_state()
Expand Down Expand Up @@ -201,7 +201,7 @@ end
device()::CuDevice
Get the CUDA device for the current thread, similar to how [`context()`](@ref) works
compared to [`CuCurrentContext()`](@ref).
compared to [`current_context()`](@ref).
"""
function device()
task_local_state!().device
Expand Down
3 changes: 2 additions & 1 deletion test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Adapt

@testset "constructors" begin
xs = CuArray{Int}(undef, 2, 3)
@test device(xs) == device()
@test collect(CuArray([1 2; 3 4])) == [1 2; 3 4]
@test collect(cu[1, 2, 3]) == [1, 2, 3]
@test collect(cu([1, 2, 3])) == [1, 2, 3]
Expand Down Expand Up @@ -51,7 +52,7 @@ import Adapt
gpu_ptr = convert(CuPtr{Int}, buf)
gpu_arr = Base.unsafe_wrap(CuArray, gpu_ptr, 1)
gpu_arr .= 42

synchronize()

cpu_ptr = convert(Ptr{Int}, buf)
Expand Down
Loading

0 comments on commit 0d4878a

Please sign in to comment.