add support for half-precision gemm
bjarthur authored and maleadt committed Aug 5, 2021
1 parent 5710e25 commit 381ac17
Showing 3 changed files with 118 additions and 83 deletions.
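For context, a minimal usage sketch of the new Float16 entry points added below (a sketch, not part of the commit: the array sizes are illustrative, it assumes a device with compute capability 5.3 or later to match the guard in test/cublas.jl, and the gemm_batched/gemm_strided_batched signatures are taken from the tests in this diff):

using CUDA

# Batched GEMM over a vector of Float16 matrices: C[i] = A[i] * B[i].
# Data is generated on the host and copied over, since this sketch does
# not assume half-precision support in CURAND.
A = [CuArray(rand(Float16, 4, 4)) for _ in 1:8]
B = [CuArray(rand(Float16, 4, 4)) for _ in 1:8]
C = CUDA.CUBLAS.gemm_batched('N', 'N', A, B)

# Strided-batched GEMM over a 3-D array, one matrix per slice.
A3 = CuArray(rand(Float16, 4, 4, 8))
B3 = CuArray(rand(Float16, 4, 4, 8))
C3 = CUDA.CUBLAS.gemm_strided_batched('N', 'N', A3, B3)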
24 changes: 24 additions & 0 deletions lib/cublas/libcublas.jl
@@ -1796,6 +1796,17 @@ end
handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc)
end

@checked function cublasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda,
Barray, ldb, beta, Carray, ldc, batchCount)
initialize_api()
ccall((:cublasHgemmBatched, libcublas()), cublasStatus_t,
(cublasHandle_t, cublasOperation_t, cublasOperation_t, Cint, Cint,
Cint, RefOrCuRef{Float16}, CuPtr{Ptr{Float16}}, Cint, CuPtr{Ptr{Float16}},
Cint, RefOrCuRef{Float16}, CuPtr{Ptr{Float16}}, Cint, Cint),
handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta,
Carray, ldc, batchCount)
end

@checked function cublasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda,
Barray, ldb, beta, Carray, ldc, batchCount)
initialize_api()
@@ -1885,6 +1896,19 @@ end
computeType, algo)
end

@checked function cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda,
strideA, B, ldb, strideB, beta, C, ldc,
strideC, batchCount)
initialize_api()
ccall((:cublasHgemmStridedBatched, libcublas()), cublasStatus_t,
(cublasHandle_t, cublasOperation_t, cublasOperation_t, Cint, Cint,
Cint, RefOrCuRef{Float16}, CuPtr{Float16}, Cint, Clonglong,
CuPtr{Float16}, Cint, Clonglong, RefOrCuRef{Float16}, CuPtr{Float16},
Cint, Clonglong, Cint),
handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb,
strideB, beta, C, ldc, strideC, batchCount)
end

@checked function cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda,
strideA, B, ldb, strideB, beta, C, ldc,
strideC, batchCount)
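Like the existing S/D/C/Z variants, the strided entry point takes explicit element strides between consecutive matrices in the batch. A rough illustration of how those arguments relate to a dense column-major 3-D array (an assumption for exposition; the actual stride computation lives in wrappers.jl and is not part of this hunk):

using CUDA

A = CuArray(rand(Float16, 4, 3, 8))   # an m×k×b batch
m, k, b = size(A)
lda     = m        # leading dimension of each matrix
strideA = m * k    # elements between the starts of consecutive matrices
# cublasHgemmStridedBatched reads matrix i at offset (i-1)*strideA from
# the base pointer and multiplies all b pairs in a single call.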
2 changes: 2 additions & 0 deletions lib/cublas/wrappers.jl
@@ -924,6 +924,7 @@ end
for (fname, elty) in
((:cublasDgemmBatched,:Float64),
(:cublasSgemmBatched,:Float32),
(:cublasHgemmBatched,:Float16),
(:cublasZgemmBatched,:ComplexF64),
(:cublasCgemmBatched,:ComplexF32))
@eval begin
@@ -985,6 +986,7 @@ end
for (fname, elty) in
((:cublasDgemmStridedBatched,:Float64),
(:cublasSgemmStridedBatched,:Float32),
(:cublasHgemmStridedBatched,:Float16),
(:cublasZgemmStridedBatched,:ComplexF64),
(:cublasCgemmStridedBatched,:ComplexF32))
@eval begin
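Because the wrappers are generated from a single @eval template per (function, element type) pair, adding the Float16 entries extends gemm_batched! and gemm_strided_batched! to Float16 without duplicating any wrapper code. A self-contained toy version of the mechanism (illustrative only; the real template also checks dimensions and assembles the ccall arguments):

# Each loop iteration splices one (binding, type) pair into the
# template, defining a typed method per element type.
for (fname, elty) in ((:cublasHgemmBatched, :Float16),
                      (:cublasSgemmBatched, :Float32))
    @eval binding_for(::Type{$elty}) = $(string(fname))
end

binding_for(Float16)   # returns "cublasHgemmBatched"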
175 changes: 92 additions & 83 deletions test/cublas.jl
@@ -650,89 +650,6 @@ end
@test C ≈ h_C
@test C ≈ h_C2
end
# generate matrices
bA = [rand(elty,m,k) for i in 1:10]
bB = [rand(elty,k,n) for i in 1:10]
bC = [rand(elty,m,n) for i in 1:10]
# move to device
bd_A = CuArray{elty, 2}[]
bd_B = CuArray{elty, 2}[]
bd_C = CuArray{elty, 2}[]
bd_bad = CuArray{elty, 2}[]
for i in 1:length(bA)
push!(bd_A,CuArray(bA[i]))
push!(bd_B,CuArray(bB[i]))
push!(bd_C,CuArray(bC[i]))
if i < length(bA) - 2
push!(bd_bad,CuArray(bC[i]))
end
end
@testset "gemm_batched!" begin
# C = (alpha*A)*B + beta*C
CUBLAS.gemm_batched!('N','N',alpha,bd_A,bd_B,beta,bd_C)
for i in 1:length(bd_C)
bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i]
h_C = Array(bd_C[i])
#compare
@test bC[i] ≈ h_C
end
@test_throws DimensionMismatch CUBLAS.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
end

@testset "gemm_batched" begin
bd_C = CUBLAS.gemm_batched('N','N',bd_A,bd_B)
for i in 1:length(bA)
bC = bA[i]*bB[i]
h_C = Array(bd_C[i])
@test bC ≈ h_C
end
@test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
end

nbatch = 10
bA = rand(elty, m, k, nbatch)
bB = rand(elty, k, n, nbatch)
bC = rand(elty, m, n, nbatch)
bbad = rand(elty, m+1, n+1, nbatch)
# move to device
bd_A = CuArray{elty, 3}(bA)
bd_B = CuArray{elty, 3}(bB)
bd_C = CuArray{elty, 3}(bC)
bd_bad = CuArray{elty, 3}(bbad)
@testset "gemm_strided_batched!" begin
CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_C)
for i in 1:nbatch
bC[:, :, i] = (alpha * bA[:, :, i]) * bB[:, :, i] + beta * bC[:, :, i]
end
h_C = Array(bd_C)
@test bC ≈ h_C
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
end

@testset "gemm_strided_batched" begin
bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)

for i in 1:nbatch
bC[:, :, i] = bA[:, :, i] * bB[:, :, i]
end
h_C = Array(bd_C)
@test bC ≈ h_C
# generate matrices
bA = rand(elty, k, m, nbatch)
bB = rand(elty, k, n, nbatch)
bC = zeros(elty, m, n, nbatch)
# move to device
bd_A = CuArray{elty, 3}(bA)
bd_B = CuArray{elty, 3}(bB)

bd_C = CUBLAS.gemm_strided_batched('T', 'N', bd_A, bd_B)
for i in 1:nbatch
bC[:, :, i] = transpose(bA[:, :, i]) * bB[:, :, i]
end
h_C = Array(bd_C)
@test bC ≈ h_C
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched('N', 'N', alpha, bd_A, bd_bad)
end

B = rand(elty,m,n)
C = rand(elty,m,n)
@@ -1460,6 +1377,98 @@ end
end
end

@testset for elty in [Float16, Float32, Float64, ComplexF32, ComplexF64]
elty == Float16 && capability(device()) < v"5.3" && continue

alpha = rand(elty)
beta = rand(elty)
# generate matrices
bA = [rand(elty,m,k) for i in 1:10]
bB = [rand(elty,k,n) for i in 1:10]
bC = [rand(elty,m,n) for i in 1:10]
# move to device
bd_A = CuArray{elty, 2}[]
bd_B = CuArray{elty, 2}[]
bd_C = CuArray{elty, 2}[]
bd_bad = CuArray{elty, 2}[]
for i in 1:length(bA)
push!(bd_A,CuArray(bA[i]))
push!(bd_B,CuArray(bB[i]))
push!(bd_C,CuArray(bC[i]))
if i < length(bA) - 2
push!(bd_bad,CuArray(bC[i]))
end
end

@testset "gemm_batched!" begin
# C = (alpha*A)*B + beta*C
CUBLAS.gemm_batched!('N','N',alpha,bd_A,bd_B,beta,bd_C)
for i in 1:length(bd_C)
bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i]
h_C = Array(bd_C[i])
#compare
@test bC[i] ≈ h_C
end
@test_throws DimensionMismatch CUBLAS.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
end

@testset "gemm_batched" begin
bd_C = CUBLAS.gemm_batched('N','N',bd_A,bd_B)
for i in 1:length(bA)
bC = bA[i]*bB[i]
h_C = Array(bd_C[i])
@test bC ≈ h_C
end
@test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
end

nbatch = 10
bA = rand(elty, m, k, nbatch)
bB = rand(elty, k, n, nbatch)
bC = rand(elty, m, n, nbatch)
bbad = rand(elty, m+1, n+1, nbatch)
# move to device
bd_A = CuArray{elty, 3}(bA)
bd_B = CuArray{elty, 3}(bB)
bd_C = CuArray{elty, 3}(bC)
bd_bad = CuArray{elty, 3}(bbad)

@testset "gemm_strided_batched!" begin
CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_C)
for i in 1:nbatch
bC[:, :, i] = (alpha * bA[:, :, i]) * bB[:, :, i] + beta * bC[:, :, i]
end
h_C = Array(bd_C)
@test bC ≈ h_C
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
end

@testset "gemm_strided_batched" begin
bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)

for i in 1:nbatch
bC[:, :, i] = bA[:, :, i] * bB[:, :, i]
end
h_C = Array(bd_C)
@test bC ≈ h_C
# generate matrices
bA = rand(elty, k, m, nbatch)
bB = rand(elty, k, n, nbatch)
bC = zeros(elty, m, n, nbatch)
# move to device
bd_A = CuArray{elty, 3}(bA)
bd_B = CuArray{elty, 3}(bB)

bd_C = CUBLAS.gemm_strided_batched('T', 'N', bd_A, bd_B)
for i in 1:nbatch
bC[:, :, i] = transpose(bA[:, :, i]) * bB[:, :, i]
end
h_C = Array(bd_C)
@test bC ≈ h_C
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched('N', 'N', alpha, bd_A, bd_bad)
end
end

@testset "mixed-precision matmul" begin
m,k,n = 4,4,4
cudaTypes = (Float16, Complex{Float16}, BFloat16, Complex{BFloat16}, Float32, Complex{Float32},
