diff --git a/lib/cudadrv/context.jl b/lib/cudadrv/context.jl index 9a78aabdd3..be44c0840f 100644 --- a/lib/cudadrv/context.jl +++ b/lib/cudadrv/context.jl @@ -1,9 +1,9 @@ # Context management export - CuPrimaryContext, CuContext, CuCurrentContext, activate, + CuPrimaryContext, CuContext, current_context, activate, unsafe_reset!, isactive, flags, setflags!, - device_synchronize, CuCurrentDevice, + device, device_synchronize ## construction and destruction @@ -67,11 +67,11 @@ mutable struct CuContext end """ - CuCurrentContext() + current_context() Return the current context, or `nothing` if there is no active context. """ - global function CuCurrentContext() + global function current_context() handle_ref = Ref{CUcontext}() cuCtxGetCurrent(handle_ref) if handle_ref[] == C_NULL @@ -81,14 +81,8 @@ mutable struct CuContext end end - """ - CuContext(ptr) - - Identify the context a CUDA memory buffer was allocated in. - """ - function CuContext(x::Union{Ptr,CuPtr}) - new_unique(attribute(CUcontext, x, POINTER_ATTRIBUTE_CONTEXT)) - end + # for outer constructors + global _CuContext(handle::CUcontext) = new_unique(handle) end # the `valid` bit serves two purposes: make sure we don't double-free a context (in case we @@ -206,7 +200,7 @@ does not respect any users of the context, and might make other objects unusable """ function unsafe_release!(ctx::CuContext) if isvalid(ctx) - dev = CuDevice(ctx) + dev = device(ctx) pctx = CuPrimaryContext(dev) if version() >= v"11" cuDevicePrimaryCtxRelease_v2(dev) @@ -283,29 +277,13 @@ end ## context properties """ - CuCurrentDevice() - -Returns the current device, or `nothing` if there is no active device. -""" -function CuCurrentDevice() - device_ref = Ref{CUdevice}() - res = unsafe_cuCtxGetDevice(device_ref) - if res == ERROR_INVALID_CONTEXT - return nothing - elseif res != SUCCESS - throw_api_error(res) - end - return CuDevice(Bool, device_ref[]) -end - -""" - CuDevice(::CuContext) + device(::CuContext) Returns the device for a context. """ -function CuDevice(ctx::CuContext) +function device(ctx::CuContext) push!(CuContext, ctx) - dev = CuCurrentDevice() + dev = current_device() pop!(CuContext) return dev end diff --git a/lib/cudadrv/devices.jl b/lib/cudadrv/devices.jl index 2057642764..0c5b557e94 100644 --- a/lib/cudadrv/devices.jl +++ b/lib/cudadrv/devices.jl @@ -1,31 +1,45 @@ # Device type and auxiliary functions export - CuDevice, name, uuid, totalmem, attribute + CuDevice, current_device, name, uuid, totalmem, attribute - -""" - CuDevice(i::Integer) - -Get a handle to a compute device. -""" struct CuDevice handle::CUdevice - # CuDevice is just an integer, but we need (?) to call cuDeviceGet to make sure this - # integer is valid. to avoid ambiguity, add a bogus argument (cfr. `checkbounds`) - CuDevice(::Type{Bool}, handle::CUdevice) = new(handle) -end + """ + CuDevice(ordinal::Integer) -const DEVICE_CPU = CuDevice(Bool, CUdevice(-1)) -const DEVICE_INVALID = CuDevice(Bool, CUdevice(-2)) + Get a handle to a compute device. + """ + function CuDevice(ordinal::Integer) + device_ref = Ref{CUdevice}() + cuDeviceGet(device_ref, ordinal) + new(device_ref[]) + end -function CuDevice(ordinal::Integer) - device_ref = Ref{CUdevice}() - cuDeviceGet(device_ref, ordinal) - CuDevice(Bool, device_ref[]) + """ + current_device() + + Returns the current device, or `nothing` if there is no active device. + """ + global function current_device() + device_ref = Ref{CUdevice}() + res = unsafe_cuCtxGetDevice(device_ref) + if res == ERROR_INVALID_CONTEXT + return nothing + elseif res != SUCCESS + throw_api_error(res) + end + return _CuDevice(device_ref[]) + end + + # for outer constructors + global _CuDevice(handle::CUdevice) = new(handle) end +const DEVICE_CPU = _CuDevice(CUdevice(-1)) +const DEVICE_INVALID = _CuDevice(CUdevice(-2)) + Base.convert(::Type{CUdevice}, dev::CuDevice) = dev.handle Base.:(==)(a::CuDevice, b::CuDevice) = a.handle == b.handle diff --git a/lib/cudadrv/events.jl b/lib/cudadrv/events.jl index c63d1a2c7b..d2d1377328 100644 --- a/lib/cudadrv/events.jl +++ b/lib/cudadrv/events.jl @@ -18,7 +18,7 @@ mutable struct CuEvent handle_ref = Ref{CUevent}() cuEventCreate(handle_ref, flags) - ctx = CuCurrentContext() + ctx = current_context() obj = new(handle_ref[], ctx) finalizer(unsafe_destroy!, obj) return obj diff --git a/lib/cudadrv/graph.jl b/lib/cudadrv/graph.jl index ef3f98a51d..e19b2542f5 100644 --- a/lib/cudadrv/graph.jl +++ b/lib/cudadrv/graph.jl @@ -15,7 +15,7 @@ mutable struct CuGraph handle_ref = Ref{CUgraph}() cuGraphCreate(handle_ref, flags) - ctx = CuCurrentContext() + ctx = current_context() obj = new(handle_ref[], ctx) finalizer(unsafe_destroy!, obj) return obj @@ -39,7 +39,7 @@ mutable struct CuGraph global function capture(f::Function; flags=STREAM_CAPTURE_MODE_GLOBAL, throw_error::Bool=true) cuStreamBeginCapture_v2(stream(), flags) - ctx = CuCurrentContext() + ctx = current_context() obj = nothing try f() @@ -93,7 +93,7 @@ mutable struct CuGraphExec # TODO: how to use these? end - ctx = CuCurrentContext() + ctx = current_context() obj = new(handle_ref[], graph, ctx) finalizer(unsafe_destroy!, obj) return obj diff --git a/lib/cudadrv/memory.jl b/lib/cudadrv/memory.jl index 5e8941d3a9..7b20bdef33 100644 --- a/lib/cudadrv/memory.jl +++ b/lib/cudadrv/memory.jl @@ -738,13 +738,27 @@ end # some common attributes +""" + context(ptr) + +Identify the context a CUDA memory buffer was allocated in. +""" +context(ptr::Union{Ptr,CuPtr}) = + _CuContext(attribute(CUcontext, ptr, POINTER_ATTRIBUTE_CONTEXT)) + +""" + device(ptr) + +Identify the device a CUDA memory buffer was allocated on. +""" +device(x::Union{Ptr,CuPtr}) = + CuDevice(convert(Int, attribute(Cuint, x, POINTER_ATTRIBUTE_DEVICE_ORDINAL))) + @enum_without_prefix CUmemorytype CU_ memory_type(x) = CUmemorytype(attribute(Cuint, x, POINTER_ATTRIBUTE_MEMORY_TYPE)) is_managed(x) = convert(Bool, attribute(Cuint, x, POINTER_ATTRIBUTE_IS_MANAGED)) -CuDevice(x::Union{Ptr,CuPtr}) = CuDevice(convert(Int, attribute(Cuint, x, POINTER_ATTRIBUTE_DEVICE_ORDINAL))) - function is_pinned(ptr::Ptr) # unpinned memory makes cuPointerGetAttribute return ERROR_INVALID_VALUE; but instead of # calling `memory_type` with an expensive try/catch we perform low-level API calls. diff --git a/lib/cudadrv/module.jl b/lib/cudadrv/module.jl index 38a5879aad..ccb1f2616b 100644 --- a/lib/cudadrv/module.jl +++ b/lib/cudadrv/module.jl @@ -70,7 +70,7 @@ mutable struct CuModule end end - ctx = CuCurrentContext() + ctx = current_context() obj = new(handle_ref[], ctx) finalizer(unsafe_unload!, obj) return obj diff --git a/lib/cudadrv/module/linker.jl b/lib/cudadrv/module/linker.jl index 5090c8e904..c178252af3 100644 --- a/lib/cudadrv/module/linker.jl +++ b/lib/cudadrv/module/linker.jl @@ -36,7 +36,7 @@ mutable struct CuLink cuLinkCreate_v2(length(optionKeys), optionKeys, optionVals, handle_ref) - ctx = CuCurrentContext() + ctx = current_context() obj = new(handle_ref[], ctx, options, optionKeys, optionVals) finalizer(unsafe_destroy!, obj) return obj diff --git a/lib/cudadrv/occupancy.jl b/lib/cudadrv/occupancy.jl index 96b88242e9..223c86d5c2 100644 --- a/lib/cudadrv/occupancy.jl +++ b/lib/cudadrv/occupancy.jl @@ -25,7 +25,7 @@ function occupancy(fun::CuFunction, threads::Integer; shmem::Integer=0) mod = fun.mod ctx = mod.ctx - dev = CuDevice(ctx) + dev = device(ctx) threads_per_sm = attribute(dev, DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR) warp_size = attribute(dev, DEVICE_ATTRIBUTE_WARP_SIZE) diff --git a/lib/cudadrv/pool.jl b/lib/cudadrv/pool.jl index 968227c295..b5d1b7d32d 100644 --- a/lib/cudadrv/pool.jl +++ b/lib/cudadrv/pool.jl @@ -25,7 +25,7 @@ mutable struct CuMemoryPool handle_ref = Ref{CUmemoryPool}() cuMemPoolCreate(handle_ref, props) - ctx = CuCurrentContext() + ctx = current_context() obj = new(handle_ref[], ctx) finalizer(unsafe_destroy!, obj) return obj @@ -35,7 +35,7 @@ mutable struct CuMemoryPool handle_ref = Ref{CUmemoryPool}() cuDeviceGetDefaultMemPool(handle_ref, dev) - ctx = CuCurrentContext() + ctx = current_context() new(handle_ref[], ctx) end @@ -43,7 +43,7 @@ mutable struct CuMemoryPool handle_ref = Ref{CUmemoryPool}() cuDeviceGetMemPool(handle_ref, dev) - ctx = CuCurrentContext()::CuContext + ctx = current_context()::CuContext new(handle_ref[], ctx) end end diff --git a/lib/cudadrv/stream.jl b/lib/cudadrv/stream.jl index 9361442dcb..56508b5183 100644 --- a/lib/cudadrv/stream.jl +++ b/lib/cudadrv/stream.jl @@ -24,7 +24,7 @@ mutable struct CuStream cuStreamCreateWithPriority(handle_ref, flags, priority) end - ctx = CuCurrentContext()::CuContext + ctx = current_context()::CuContext obj = new(handle_ref[], ctx) finalizer(unsafe_destroy!, obj) return obj diff --git a/src/array.jl b/src/array.jl index 6a224021ee..43bb567157 100644 --- a/src/array.jl +++ b/src/array.jl @@ -225,6 +225,11 @@ Base.elsize(::Type{<:CuArray{T}}) where {T} = sizeof(T) Base.size(x::CuArray) = x.dims Base.sizeof(x::CuArray) = Base.elsize(x) * length(x) +function device(A::CuArray) + A.storage === nothing && throw(UndefRefError()) + return device(A.storage.ctx) +end + ## derived types diff --git a/src/compiler/exceptions.jl b/src/compiler/exceptions.jl index b00e330e0f..10e1967a46 100644 --- a/src/compiler/exceptions.jl +++ b/src/compiler/exceptions.jl @@ -30,7 +30,7 @@ function check_exceptions() flag = unsafe_load(ptr) if flag != 0 unsafe_store!(ptr, 0) - dev = CuDevice(ctx) + dev = device(ctx) throw(KernelException(dev)) end end diff --git a/src/deprecated.jl b/src/deprecated.jl index 07e2a06c87..26ccb0c2fc 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1 +1,7 @@ # Deprecated functionality + +@deprecate CuDevice(ctx::CuContext) device(ctx) +@deprecate CuCurrentDevice() current_device() +@deprecate CuCurrentContext() current_context() +@deprecate CuContext(ptr::Union{Ptr,CuPtr}) context(ptr) +@deprecate CuDevice(ptr::Union{Ptr,CuPtr}) device(ptr) diff --git a/src/state.jl b/src/state.jl index 9157377682..dc6d3c1711 100644 --- a/src/state.jl +++ b/src/state.jl @@ -87,9 +87,9 @@ end @inline function prepare_cuda_state() state = task_local_state!() - # NOTE: CuCurrentContext() is too slow to use here (taking a lock, accessing a dict) + # NOTE: current_context() is too slow to use here (taking a lock, accessing a dict) # so we use the raw handle. is that safe though, when we reset the device? - #ctx = CuCurrentContext() + #ctx = current_context() ctx = Ref{CUcontext}() cuCtxGetCurrent(ctx) if ctx[] != state.context.handle @@ -115,7 +115,7 @@ end context()::CuContext Get or create a CUDA context for the current thread (as opposed to -`CuCurrentContext` which may return `nothing` if there is no context bound to the +`current_context` which may return `nothing` if there is no context bound to the current thread). """ function context() @@ -133,7 +133,7 @@ Note that the contexts used with this call should be previously acquired by call function context!(ctx::CuContext) activate(ctx) # we generally only apply CUDA state lazily, i.e. in `prepare_cuda_state`, # but we need to do so early here to be able to get the context's device. - dev = CuCurrentDevice()::CuDevice + dev = current_device()::CuDevice # switch contexts state = task_local_state() @@ -201,7 +201,7 @@ end device()::CuDevice Get the CUDA device for the current thread, similar to how [`context()`](@ref) works -compared to [`CuCurrentContext()`](@ref). +compared to [`current_context()`](@ref). """ function device() task_local_state!().device diff --git a/test/array.jl b/test/array.jl index 8287742ca3..b849da288b 100644 --- a/test/array.jl +++ b/test/array.jl @@ -3,6 +3,7 @@ import Adapt @testset "constructors" begin xs = CuArray{Int}(undef, 2, 3) + @test device(xs) == device() @test collect(CuArray([1 2; 3 4])) == [1 2; 3 4] @test collect(cu[1, 2, 3]) == [1, 2, 3] @test collect(cu([1, 2, 3])) == [1, 2, 3] @@ -51,7 +52,7 @@ import Adapt gpu_ptr = convert(CuPtr{Int}, buf) gpu_arr = Base.unsafe_wrap(CuArray, gpu_ptr, 1) gpu_arr .= 42 - + synchronize() cpu_ptr = convert(Ptr{Int}, buf) diff --git a/test/cudadrv.jl b/test/cudadrv.jl index add6eeda44..4079cb96d5 100644 --- a/test/cudadrv.jl +++ b/test/cudadrv.jl @@ -1,31 +1,31 @@ @testset "context" begin -ctx = CuCurrentContext() -dev = CuCurrentDevice() +ctx = current_context() +dev = current_device() synchronize(ctx) let ctx2 = CuContext(dev) - @test ctx2 == CuCurrentContext() # ctor implicitly pushes + @test ctx2 == current_context() # ctor implicitly pushes activate(ctx) - @test ctx == CuCurrentContext() + @test ctx == current_context() - @test CuDevice(ctx2) == dev + @test device(ctx2) == dev CUDA.unsafe_destroy!(ctx2) end let global_ctx2 = nothing CuContext(dev) do ctx2 - @test ctx2 == CuCurrentContext() + @test ctx2 == current_context() @test ctx != ctx2 global_ctx2 = ctx2 end @test !CUDA.isvalid(global_ctx2) - @test ctx == CuCurrentContext() + @test ctx == current_context() - @test CuDevice(ctx) == dev - @test CuCurrentDevice() == dev + @test device(ctx) == dev + @test current_device() == dev device_synchronize() end @@ -453,7 +453,7 @@ for srcTy in [Mem.Device, Mem.Host, Mem.Unified], end # test device with context in which pointer was allocated. - @test CuDevice(typed_pointer(src, T)) == device() + @test device(typed_pointer(src, T)) == device() if !CUDA.has_stream_ordered(device()) # NVIDIA bug #3319609 @test CuContext(typed_pointer(src, T)) == context() diff --git a/test/initialization.jl b/test/initialization.jl index def73923c7..6784bf4288 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -2,20 +2,20 @@ @test has_cuda_gpu(true) # the API shouldn't have been initialized -@test CuCurrentContext() == nothing -@test CuCurrentDevice() == nothing +@test current_context() == nothing +@test current_device() == nothing ctx = context() dev = device() # querying Julia's side of things shouldn't cause initialization -@test CuCurrentContext() == nothing -@test CuCurrentDevice() == nothing +@test current_context() == nothing +@test current_device() == nothing # now cause initialization a = CuArray([42]) -@test CuCurrentContext() == ctx -@test CuCurrentDevice() == dev +@test current_context() == ctx +@test current_device() == dev # ... on a different task task = @async begin