Skip to content

Commit

Permalink
Avoid allocating diagonal zeros in sparse(::Diagonal) (JuliaLang#42577
Browse files Browse the repository at this point in the history
)

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
mcognetta and dkarrasch authored Nov 21, 2021
1 parent e9430c9 commit a4f7f2d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
27 changes: 26 additions & 1 deletion stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,33 @@ end
SparseMatrixCSC(D::Diagonal{Tv}) where Tv = SparseMatrixCSC{Tv,Int}(D)
function SparseMatrixCSC{Tv,Ti}(D::Diagonal) where {Tv,Ti}
m = length(D.diag)
return SparseMatrixCSC(m, m, Vector(Ti(1):Ti(m+1)), Vector(Ti(1):Ti(m)), Vector{Tv}(D.diag))
m == 0 && return SparseMatrixCSC{Tv,Ti}(zeros(Tv, 0, 0))

nz = count(!iszero, D.diag)
nz_counter = 1

rowval = Vector{Ti}(undef, nz)
nzval = Vector{Tv}(undef, nz)

nz == 0 && return SparseMatrixCSC{Tv,Ti}(m, m, ones(Ti, m+1), rowval, nzval)

colptr = Vector{Ti}(undef, m+1)

@inbounds for i=1:m
if !iszero(D.diag[i])
colptr[i] = nz_counter
rowval[nz_counter] = i
nzval[nz_counter] = D.diag[i]
nz_counter += 1
else
colptr[i] = nz_counter
end
end
colptr[end] = nz_counter

return SparseMatrixCSC{Tv,Ti}(m, m, colptr, rowval, nzval)
end

SparseMatrixCSC(M::AbstractMatrix{Tv}) where {Tv} = SparseMatrixCSC{Tv,Int}(M)
SparseMatrixCSC{Tv}(M::AbstractMatrix{Tv}) where {Tv} = SparseMatrixCSC{Tv,Int}(M)
function SparseMatrixCSC{Tv,Ti}(M::AbstractMatrix) where {Tv,Ti}
Expand Down
8 changes: 8 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,14 @@ end
@test sparse(s) == s
end

@testset "avoid allocation for zeros in diagonal" begin
x = [1, 0, 0, 5, 0]
d = Diagonal(x)
s = sparse(d)
@test s == d
@test nnz(s) == 2
end

@testset "error conditions for reshape, and dropdims" begin
local A = sprand(Bool, 5, 5, 0.2)
@test_throws DimensionMismatch reshape(A,(20, 2))
Expand Down

0 comments on commit a4f7f2d

Please sign in to comment.