Skip to content

Commit

Permalink
More tweaks for sparse scale/scale!
Browse files Browse the repository at this point in the history
- Implement scalar scale/scale! variants
- Copy `colptr` and `rowval` from the source to the destination instead
  of assigning them. This costs more memory, but otherwise
  adding new values to the scaled matrix would corrupt the original and
  vice-versa.
  • Loading branch information
simonster committed Jan 24, 2015
1 parent f36625b commit 7d2b079
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
58 changes: 40 additions & 18 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -614,42 +614,64 @@ inv(A::SparseMatrixCSC) = error("The inverse of a sparse matrix can often be den

## scale methods

# Copy colptr and rowval from one SparseMatrix to another
function copyinds!(C::SparseMatrixCSC, A::SparseMatrixCSC)
if C.colptr !== A.colptr
resize!(C.colptr, length(A.colptr))
copy!(C.colptr, A.colptr)
end
if C.rowval !== A.rowval
resize!(C.rowval, length(A.rowval))
copy!(C.rowval, A.rowval)
end
end

# multiply by diagonal matrix as vector
function scale!{Tv,Ti}(C::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC, b::Vector)
function scale!(C::SparseMatrixCSC, A::SparseMatrixCSC, b::Vector)
m, n = size(A)
(n==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
numnz = nnz(A)
if C !== A
C.colptr = convert(Array{Ti}, A.colptr)
C.rowval = convert(Array{Ti}, A.rowval)
C.nzval = Array(Tv, numnz)
end
copyinds!(C, A)
Cnzval = C.nzval
Anzval = A.nzval
resize!(Cnzval, length(Anzval))
for col = 1:n, p = A.colptr[col]:(A.colptr[col+1]-1)
C.nzval[p] = A.nzval[p] * b[col]
@inbounds Cnzval[p] = Anzval[p] * b[col]
end
C
end

function scale!{Tv,Ti}(C::SparseMatrixCSC{Tv,Ti}, b::Vector, A::SparseMatrixCSC)
function scale!(C::SparseMatrixCSC, b::Vector, A::SparseMatrixCSC)
m, n = size(A)
(m==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
numnz = nnz(A)
if C !== A
C.colptr = convert(Array{Ti}, A.colptr)
C.rowval = convert(Array{Ti}, A.rowval)
C.nzval = Array(Tv, numnz)
end
copyinds!(C, A)
Cnzval = C.nzval
Anzval = A.nzval
Arowval = A.rowval
resize!(Cnzval, length(Anzval))
for col = 1:n, p = A.colptr[col]:(A.colptr[col+1]-1)
C.nzval[p] = A.nzval[p] * b[A.rowval[p]]
@inbounds Cnzval[p] = Anzval[p] * b[Arowval[p]]
end
C
end

function scale!(C::SparseMatrixCSC, A::SparseMatrixCSC, b::Number)
size(A)==size(C) || throw(DimensionMismatch())
copyinds!(C, A)
resize!(C.nzval, length(A.nzval))
scale!(C.nzval, A.nzval, b)
C
end

scale!(C::SparseMatrixCSC, b::Number, A::SparseMatrixCSC) = scale!(C, A, b)

scale!(A::SparseMatrixCSC, b::Number) = (scale!(A.nzval, b); A)
scale!(b::Number, A::SparseMatrixCSC) = (scale!(b, A.nzval); A)

scale{Tv,Ti,T}(A::SparseMatrixCSC{Tv,Ti}, b::Vector{T}) =
scale!(SparseMatrixCSC(size(A,1),size(A,2),Ti[],Ti[],promote_type(Tv,T)[]), A, b)
scale!(similar(A, promote_type(Tv,T)), A, b)

scale{T,Tv,Ti}(b::Vector{T}, A::SparseMatrixCSC{Tv,Ti}) =
scale!(SparseMatrixCSC(size(A,1),size(A,2),Ti[],Ti[],promote_type(Tv,T)[]), b, A)
scale!(similar(A, promote_type(Tv,T)), b, A)

chol(A::SparseMatrixCSC) = error("Use cholfact() instead of chol() for sparse matrices.")
lu(A::SparseMatrixCSC) = error("Use lufact() instead of lu() for sparse matrices.")
Expand Down
10 changes: 10 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,24 @@ end

# scale and scale!
sA = sprandn(3, 7, 0.5)
sC = similar(sA)
dA = full(sA)
b = randn(7)
@test scale(dA, b) == scale(sA, b)
@test scale(dA, b) == scale!(sC, sA, b)
@test scale(dA, b) == scale!(copy(sA), b)
b = randn(3)
@test scale(b, dA) == scale(b, sA)
@test scale(b, dA) == scale!(sC, b, sA)
@test scale(b, dA) == scale!(b, copy(sA))

@test scale(dA, 0.5) == scale(sA, 0.5)
@test scale(dA, 0.5) == scale!(sC, sA, 0.5)
@test scale(dA, 0.5) == scale!(copy(sA), 0.5)
@test scale(0.5, dA) == scale(0.5, sA)
@test scale(0.5, dA) == scale!(sC, sA, 0.5)
@test scale(0.5, dA) == scale!(0.5, copy(sA))

# reductions
@test sum(se33)[1] == 3.0
@test sum(se33, 1) == [1.0 1.0 1.0]
Expand Down

0 comments on commit 7d2b079

Please sign in to comment.