Skip to content

Commit

Permalink
Merge pull request HSU-ANT#78 from HSU-ANT/mh/improve-inferability
Browse files Browse the repository at this point in the history
Improve inferability
  • Loading branch information
martinholters authored Feb 11, 2022
2 parents 00c8e48 + 7724b54 commit 07777a5
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 122 deletions.
230 changes: 136 additions & 94 deletions src/ACME.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ include("elements.jl")

include("circuit.jl")

mutable struct DiscreteModel{Solvers}
mutable struct DiscreteModel{Solvers<:Tuple{Vararg{NonlinearSolver}}}
a::Matrix{Float64}
b::Matrix{Float64}
c::Matrix{Float64}
Expand All @@ -131,18 +131,14 @@ mutable struct DiscreteModel{Solvers}
solvers::Solvers
x::Vector{Float64}

function DiscreteModel(mats::Dict{Symbol}, nonlinear_eq_funcs::Vector,
function DiscreteModel(mats, nonlinear_eq_funcs::Vector,
solvers::Solvers) where {Solvers}
model = new{Solvers}()

for mat in (:a, :b, :c, :pexps, :dqs, :eqs, :fqprevs, :fqs, :dy, :ey, :fy, :x0, :q0s, :y0)
setfield!(model, mat, convert(fieldtype(typeof(model), mat), mats[mat]))
end

model.nonlinear_eq_funcs = nonlinear_eq_funcs
model.solvers = solvers
model.x = zeros(nx(model))
return model
return new{Solvers}(
mats[:a], mats[:b], mats[:c], mats[:x0],
mats[:pexps], mats[:dqs], mats[:eqs], mats[:fqprevs], mats[:fqs], mats[:q0s],
mats[:dy], mats[:ey], mats[:fy], mats[:y0],
nonlinear_eq_funcs, solvers, zeros(length(mats[:x0]))
)
end
end

Expand All @@ -159,12 +155,11 @@ function DiscreteModel(circ::Circuit, t::Real, ::Type{Solver}=HomotopySolver{Cac
end

model_nns = Int[sum(nns[nles]) for nles in nl_elems]
model_qidxs = [vcat(consecranges(nqs)[nles]...) for nles in nl_elems]
split_nl_model_matrices!(mats, model_qidxs, model_nns)
model_qidxs = [reduce(vcat, consecranges(nqs)[nles]) for nles in nl_elems]
mats = merge(mats, split_nl_model_matrices(mats, model_qidxs, model_nns))

reduce_pdims!(mats)
mats = reduce_pdims!(mats)

model_nps = size.(mats[:dqs], 1)
model_nqs = size.(mats[:pexps], 1)

@assert nn(circ) == sum(model_nns)
Expand All @@ -173,7 +168,7 @@ function DiscreteModel(circ::Circuit, t::Real, ::Type{Solver}=HomotopySolver{Cac
fqs = Matrix{Float64}.(mats[:fqs])
fqprev_fulls = Matrix{Float64}.(mats[:fqprev_fulls])

model_nonlinear_eq_funcs = [
model_nonlinear_eq_funcs = Function[
let q = zeros(nq), circ_nl_func = nonlinear_eq_func(circ, nles)
@inline function(res, J, pfull, Jq, fq, z)
#copyto!(q, pfull + fq * z)
Expand All @@ -188,27 +183,30 @@ function DiscreteModel(circ::Circuit, t::Real, ::Type{Solver}=HomotopySolver{Cac
end
end for (nles, nq) in zip(nl_elems, model_nqs)]

nonlinear_eq_funcs = [
nonlinear_eq_funcs = Function[
@inline function (res, J, scratch, z)
nleq(res, J, scratch[1], scratch[2], fq, z)
end for (nleq, fq) in zip(model_nonlinear_eq_funcs, fqs)]

init_zs = [zeros(nn) for nn in model_nns]
for idx in eachindex(nonlinear_eq_funcs)
q = q0s[idx] + fqprev_fulls[idx] * vcat(init_zs...)
q = q0s[idx] + fqprev_fulls[idx] * reduce(vcat, init_zs)
init_zs[idx] = initial_solution(nonlinear_eq_funcs[idx], q, model_nns[idx])
end

while any(np -> np == 0, model_nps)
const_idxs = findall(iszero, model_nps)
const_zidxs = vcat(consecranges(model_nns)[const_idxs]...)
while true
const_idxs = findall(iszero, size.(mats[:dqs], 1))
if isempty(const_idxs)
break
end
const_zidxs = reduce(vcat, consecranges(model_nns)[const_idxs])
varying_zidxs = filter(idx -> !(idx in const_zidxs), 1:sum(model_nns))
for idx in eachindex(mats[:q0s])
mats[:q0s][idx] += mats[:fqprev_fulls][idx][:,const_zidxs] * vcat(init_zs[const_idxs]...)
mats[:q0s][idx] += mats[:fqprev_fulls][idx][:,const_zidxs] * reduce(vcat, init_zs[const_idxs])
mats[:fqprev_fulls][idx] = mats[:fqprev_fulls][idx][:,varying_zidxs]
end
mats[:x0] += mats[:c][:,const_zidxs] * vcat(init_zs[const_idxs]...)
mats[:y0] += mats[:fy][:,const_zidxs] * vcat(init_zs[const_idxs]...)
copyto!(mats[:x0], mats[:x0] + mats[:c][:,const_zidxs] * reduce(vcat, init_zs[const_idxs]))
copyto!(mats[:y0], mats[:y0] + mats[:fy][:,const_zidxs] * reduce(vcat, init_zs[const_idxs]))
deleteat!(mats[:q0s], const_idxs)
deleteat!(mats[:dq_fulls], const_idxs)
deleteat!(mats[:eq_fulls], const_idxs)
Expand All @@ -220,11 +218,10 @@ function DiscreteModel(circ::Circuit, t::Real, ::Type{Solver}=HomotopySolver{Cac
deleteat!(model_nonlinear_eq_funcs, const_idxs)
deleteat!(nonlinear_eq_funcs, const_idxs)
deleteat!(nl_elems, const_idxs)
mats[:fy] = mats[:fy][:,varying_zidxs]
mats[:c] = mats[:c][:,varying_zidxs]
reduce_pdims!(mats)
model_nps = size.(mats[:dqs], 1)
mats = merge(mats, (fy=mats[:fy][:,varying_zidxs], c=mats[:c][:,varying_zidxs]))
mats = reduce_pdims!(mats)
end
model_nps = size.(mats[:dqs], 1)

q0s = Array{Float64}.(mats[:q0s])
fqs = Array{Float64}.(mats[:fqs])
Expand Down Expand Up @@ -260,50 +257,56 @@ function DiscreteModel(circ::Circuit, t::Real, ::Type{Solver}=HomotopySolver{Cac
end

function model_matrices(circ::Circuit, t::Rational{BigInt})
lhs = convert(SparseMatrixCSC{Rational{BigInt},Int},
[mv(circ) mi(circ) mxd(circ)//t+mx(circ)//2 mq(circ);
blockdiag(topomat(circ)...) spzeros(nb(circ), nx(circ) + nq(circ))])
rhs = convert(SparseMatrixCSC{Rational{BigInt},Int},
[u0(circ) mu(circ) mxd(circ)//t-mx(circ)//2;
spzeros(nb(circ), 1+nu(circ)+nx(circ))])
lhs = Rational{BigInt}[
mv(circ) mi(circ) mxd(circ)//t+mx(circ)//2 mq(circ);
blockdiag(topomat(circ)...) spzeros(nb(circ), nx(circ) + nq(circ))
]
rhs = Rational{BigInt}[
u0(circ) mu(circ) mxd(circ)//t-mx(circ)//2;
spzeros(Rational{BigInt}, nb(circ), 1+nu(circ)+nx(circ))
]
x, f = Matrix.(gensolve(lhs, rhs))

rowsizes = [nb(circ); nb(circ); nx(circ); nq(circ)]
res = Dict{Symbol,Array}(zip([:fv; :fi; :c; :fq], matsplit(f, rowsizes)))
rowsizes = (nb(circ), nb(circ), nx(circ), nq(circ))
rowranges = consecranges(rowsizes)
fq = f[rowranges[4], :]

nullspace = gensolve(sparse(res[:fq]::Matrix{Rational{BigInt}}),
spzeros(Rational{BigInt}, size(res[:fq],1), 0))[2]
nullspace = gensolve(sparse(fq), spzeros(Rational{BigInt}, size(fq, 1), 0))[2]
indeterminates = f * nullspace

if sum(abs2, res[:c] * nullspace) > 1e-20
if sum(abs2, indeterminates[rowranges[3],:]) > 1e-20
@warn "State update depends on indeterminate quantity"
end

while size(nullspace, 2) > 0
i, j = argmax(abs.(nullspace)).I
i, j = (argmax(abs.(nullspace))::CartesianIndex{2}).I # argmax cannot be inferred prior to Julia 1.7
nullspace = nullspace[[1:i-1; i+1:end], [1:j-1; j+1:end]]
f = f[:, [1:i-1; i+1:end]]
for k in [:fv; :fi; :c; :fq]
res[k] = res[k][:, [1:i-1; i+1:end]]
end
end

merge!(res, Dict(zip([:v0 :ev :dv; :i0 :ei :di; :x0 :b :a; :q0 :eq_full :dq_full],
matsplit(x, rowsizes, [1; nu(circ); nx(circ)]))))
for v in (:v0, :i0, :x0, :q0)
res[v] = dropdims(res[v], dims=2)
end
f_split = NamedTuple{(:fv, :fi, :c, :fq)}(matsplit(f, rowsizes))

x_split = NamedTuple{(:v0, :i0, :x0, :q0, :ev, :ei, :b, :eq_full, :dv, :di, :a, :dq_full)}(
matsplit(x, rowsizes, (1, nu(circ), nx(circ)))
)
vs = (:v0, :i0, :x0, :q0)
x_split_replace0 = NamedTuple{vs}(
map(let x_split=x_split; v -> dropdims(x_split[v], dims=2); end, vs)
)

p = [pv(circ) pi(circ) px(circ)//2+pxd(circ)//t pq(circ)]
if sum(abs2, p * indeterminates) > 1e-20
@warn "Model output depends on indeterminate quantity"
end
res[:dy] = p * x[:,2+nu(circ):end] + px(circ)//2-pxd(circ)//t
# p * [dv; di; a; dq_full] + px(circ)//2-pxd(circ)//t
res[:ey] = p * x[:,2:1+nu(circ)] # p * [ev; ei; b; eq_full]
res[:fy] = p * f # p * [fv; fi; c; fq]
res[:y0] = p * vec(x[:,1]) # p * [v0; i0; x0; q0]
y_split = (
dy = p * x[:,2+nu(circ):end] + px(circ)//2-pxd(circ)//t,
# p * [dv; di; a; dq_full] + px(circ)//2-pxd(circ)//t
ey = p * x[:,2:1+nu(circ)], # p * [ev; ei; b; eq_full]
fy = p * f, # p * [fv; fi; c; fq]
y0 = p * x[:,1], # p * [v0; i0; x0; q0]
)

return res
return merge(f_split, x_split, x_split_replace0, y_split)
end

model_matrices(circ::Circuit, t) = model_matrices(circ, Rational{BigInt}(t))
Expand Down Expand Up @@ -349,7 +352,7 @@ function nldecompose!(mats, nns, nqs)
while !isempty(rem_nles)
for sz in 1:length(rem_nles), sub in subsets(collect(rem_nles), sz)
nn_sub = sum(nns[sub])
a_update = tryextract(fq[[sub_ranges[sub]...;],rem_cols], nn_sub)
a_update = tryextract(fq[reduce(vcat, sub_ranges[sub]), rem_cols], nn_sub)
if a_update !== nothing
fq[:,rem_cols] = fq[:,rem_cols] * a_update
a[:,rem_cols] = a[:,rem_cols] * a_update
Expand All @@ -363,38 +366,49 @@ function nldecompose!(mats, nns, nqs)
end
end

mats[:c] = mats[:c] * a
copyto!(mats[:c], mats[:c] * a)
# mats[:fq] is updated as part of the loop
mats[:fy] = mats[:fy] * a
copyto!(mats[:fy], mats[:fy] * a)
return extracted_subs
end


function split_nl_model_matrices!(mats, model_qidxs, model_nns)
mats[:dq_fulls] = Matrix[mats[:dq_full][qidxs,:] for qidxs in model_qidxs]
mats[:eq_fulls] = Matrix[mats[:eq_full][qidxs,:] for qidxs in model_qidxs]
let fqsplit = vcat((matsplit(mats[:fq][qidxs,:], [length(qidxs)], model_nns) for qidxs in model_qidxs)...)
mats[:fqs] = Matrix[fqsplit[i,i] for i in 1:length(model_qidxs)]
mats[:fqprev_fulls] = Matrix[[fqsplit[i, 1:i-1]... zeros(eltype(mats[:fq]), length(model_qidxs[i]), sum(model_nns[i:end]))]
for i in 1:length(model_qidxs)]
end
mats[:q0s] = Vector[mats[:q0][qidxs] for qidxs in model_qidxs]
function split_nl_model_matrices(mats, model_qidxs, model_nns)
fqsplit = reduce(
vcat,
matsplit(mats[:fq][qidxs,:], [length(qidxs)], model_nns) for qidxs in model_qidxs;
init=Matrix{typeof(mats[:fq])}(undef, 0, length(model_nns))
)
return (
dq_fulls = [mats[:dq_full][qidxs,:] for qidxs in model_qidxs],
eq_fulls = [mats[:eq_full][qidxs,:] for qidxs in model_qidxs],
fqs = [fqsplit[i,i] for i in 1:length(model_qidxs)],
fqprev_fulls = [
foldr(
hcat,
fqsplit[i, 1:i-1];
init = zeros(eltype(mats[:fq]), length(model_qidxs[i]), sum(model_nns[i:end]))
)
for i in 1:length(model_qidxs)
],
q0s = [mats[:q0][qidxs] for qidxs in model_qidxs]
)
end

function reduce_pdims!(mats::Dict)
function reduce_pdims!(mats)
subcount = length(mats[:dq_fulls])
mats[:dqs] = Vector{Matrix}(undef, subcount)
mats[:eqs] = Vector{Matrix}(undef, subcount)
mats[:fqprevs] = Vector{Matrix}(undef, subcount)
mats[:pexps] = Vector{Matrix}(undef, subcount)
dqs = Vector{eltype(mats[:dq_fulls])}(undef, subcount)
eqs = Vector{eltype(mats[:eq_fulls])}(undef, subcount)
fqprevs = Vector{eltype(mats[:fqprev_fulls])}(undef, subcount)
pexps = Vector{Matrix{Rational{BigInt}}}(undef, subcount)
offset = 0
for idx in 1:subcount
# decompose [dq_full eq_full] into pexp*[dq eq] with [dq eq] having minimum
# number of rows
pexp, dqeq = rank_factorize(sparse([mats[:dq_fulls][idx] mats[:eq_fulls][idx] mats[:fqprev_fulls][idx]]))
mats[:pexps][idx] = pexp
colsizes = [size(mats[m][idx], 2) for m in [:dq_fulls, :eq_fulls, :fqprev_fulls]]
mats[:dqs][idx], mats[:eqs][idx], mats[:fqprevs][idx] = matsplit(dqeq, [size(dqeq, 1)], colsizes)
pexps[idx] = pexp
colsizes = map(m -> size(mats[m][idx], 2)::Int, (:dq_fulls, :eq_fulls, :fqprev_fulls))
dqs[idx], eqs[idx], fqprevs[idx] = matsplit(dqeq, (size(dqeq, 1),), colsizes)

# project pexp onto the orthogonal complement of the column space of Fq
fq = mats[:fqs][idx]
Expand All @@ -403,27 +417,32 @@ function reduce_pdims!(mats::Dict)
pexp = pexp - fq*fq_pinv*pexp
# if the new pexp has lower rank, update
pexp, f = rank_factorize(sparse(pexp))
if size(pexp, 2) < size(mats[:pexps][idx], 2)
if size(pexp, 2) < size(pexps[idx], 2)
cols = offset .+ (1:nn)
mats[:a] = mats[:a] - mats[:c][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:dqs][idx]
mats[:b] = mats[:b] - mats[:c][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:eqs][idx]
mats[:dy] = mats[:dy] - mats[:fy][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:dqs][idx]
mats[:ey] = mats[:ey] - mats[:fy][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:eqs][idx]
copyto!(mats[:a], mats[:a] - mats[:c][:,cols]*fq_pinv*pexps[idx]*dqs[idx])
copyto!(mats[:b], mats[:b] - mats[:c][:,cols]*fq_pinv*pexps[idx]*eqs[idx])
copyto!(mats[:dy], mats[:dy] - mats[:fy][:,cols]*fq_pinv*pexps[idx]*dqs[idx])
copyto!(mats[:ey], mats[:ey] - mats[:fy][:,cols]*fq_pinv*pexps[idx]*eqs[idx])
for idx2 in (idx+1):subcount
mats[:dq_fulls][idx2] = mats[:dq_fulls][idx2] - mats[:fqprev_fulls][idx2][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:dqs][idx]
mats[:eq_fulls][idx2] = mats[:eq_fulls][idx2] - mats[:fqprev_fulls][idx2][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:eqs][idx]
mats[:fqprev_fulls][idx2][:,1:offset] = mats[:fqprev_fulls][idx2][:,1:offset] - mats[:fqprev_fulls][idx2][:,cols]*fq_pinv*mats[:pexps][idx]*mats[:fqprevs][idx][:,1:offset]
q = mats[:fqprev_fulls][idx2][:,cols]*fq_pinv*pexps[idx]
copyto!(mats[:dq_fulls][idx2], mats[:dq_fulls][idx2] - q*dqs[idx])
copyto!(mats[:eq_fulls][idx2], mats[:eq_fulls][idx2] - q*eqs[idx])
copyto!(
mats[:fqprev_fulls][idx2][:,1:offset],
mats[:fqprev_fulls][idx2][:,1:offset] - q*fqprevs[idx][:,1:offset]
)
end
mats[:pexps][idx] = pexp
mats[:dqs][idx] = f * mats[:dqs][idx]
mats[:eqs][idx] = f * mats[:eqs][idx]
mats[:fqprevs][idx] = f * mats[:fqprevs][idx]
mats[:dq_fulls][idx] = pexp * mats[:dqs][idx]
mats[:eq_fulls][idx] = pexp * mats[:eqs][idx]
mats[:fqprev_fulls][idx] = pexp * mats[:fqprevs][idx]
pexps[idx] = pexp
dqs[idx] = f * dqs[idx]
eqs[idx] = f * eqs[idx]
fqprevs[idx] = f * fqprevs[idx]
copyto!(mats[:dq_fulls][idx], pexp * dqs[idx])
copyto!(mats[:eq_fulls][idx], pexp * eqs[idx])
copyto!(mats[:fqprev_fulls][idx], pexp * fqprevs[idx])
end
offset += nn
end
return merge(mats, (dqs = dqs, eqs=eqs, fqprevs=fqprevs, pexps=pexps))
end

function initial_solution(init_nl_eq_func::Function, q0, nn)
Expand Down Expand Up @@ -695,7 +714,7 @@ function gensolve(a, b, x, h, thresh=0.1)
if m == 0
return x, h
end
t = sortperm(vec(mapslices(ait -> count(!iszero, ait), a, dims=2))) # row indexes in ascending order of nnz
t = sortperm([count(!iszero, ait) for ait in eachrow(a)]) # row indexes in ascending order of nnz
tol = 3 * max(eps(float(eltype(a))), eps(float(eltype(h)))) * size(a, 2)
for i in 1:m
ait = a[t[i],:]' # ait is a row of the a matrix
Expand All @@ -707,7 +726,7 @@ function gensolve(a, b, x, h, thresh=0.1)
continue
end
jat = jnz[nz_abs_vals .≥ thresh*max_abs_val] # cols above threshold
j = jat[argmin(vec(mapslices(hj -> count(!iszero, hj), h[:,jat], dims=1)))]
j = jat[argmin([count(!iszero, hj) for hj in eachcol(h[:,jat])])]
q = h[:,j]
x = x + convert(typeof(x), q * ((b[t[i],:]' - ait*x) * (1 / (ait*q))))
if size(h)[2] > 1
Expand All @@ -727,7 +746,7 @@ function rank_factorize(a::SparseMatrixCSC)
nullspace = gensolve(a', spzeros(eltype(a), size(a, 2), 0))[2]
c = Matrix{eltype(a)}(I, size(a, 1), size(a, 1))
while size(nullspace, 2) > 0
i, j = argmax(abs.(nullspace)).I
i, j = (argmax(abs.(nullspace))::CartesianIndex{2}).I # argmax cannot be inferred prior to Julia 1.7
c -= c[:, i] * nullspace[:, j]' / nullspace[i, j]
c = c[:, [1:i-1; i+1:end]]
nullspace -= nullspace[:, j] * vec(nullspace[i, :])' / nullspace[i, j]
Expand All @@ -737,9 +756,18 @@ function rank_factorize(a::SparseMatrixCSC)
return c, f
end

consecranges(lengths) = map((l, e) -> (e-l+1):e, lengths, cumsum(lengths))
# cumsum(::Tuple) requires Julia 1.5, so roll our own version if needed, named
# _cumsum so not to commit type piracy
_cumsum(x) = cumsum(x)
if !hasmethod(cumsum, Tuple{Tuple})
_cumsum(x::Tuple) = Base.tail(foldl((r, xi) -> (r..., r[end] + xi), x; init=(false,)))
end

consecranges(lengths) = map((l, e) -> (e-l+1):e, lengths, _cumsum(lengths))

matsplit(v::AbstractVector, rowsizes) = [v[rs] for rs in consecranges(rowsizes)]
matsplit(m::AbstractMatrix, rowsizes::Tuple, colsizes::Tuple=(size(m,2),)) =
((map(cs -> map(rs -> m[rs, cs], consecranges(rowsizes)), consecranges(colsizes))...)...,)
matsplit(m::AbstractMatrix, rowsizes, colsizes=[size(m,2)]) =
[m[rs, cs] for rs in consecranges(rowsizes), cs in consecranges(colsizes)]

Expand Down Expand Up @@ -768,4 +796,18 @@ precompile(mosfet, (Symbol,))
precompile(opamp, ())
precompile(opamp, (Type{Val{:macak}}, Float64, Float64, Float64))

precompile(Circuit, ())
precompile(add!, (Circuit, Element))
precompile(add!, (Circuit, Symbol, Element))
for T1 in (Symbol, Tuple{Symbol,Int}, Tuple{Symbol,String}, Tuple{Symbol,Symbol}),
T2 in (
Symbol, Tuple{Symbol,Int}, Tuple{Symbol,String}, Tuple{Symbol,Symbol},
Vararg{Symbol}, Vararg{Tuple{Symbol,Int}}, Vararg{Tuple{Symbol,String}}, Vararg{Tuple{Symbol,Symbol}},
Vararg{Union{Symbol,Tuple{Symbol,Any}}},
)
precompile(connect!, (Circuit, T1, T2))
end

precompile(DiscreteModel, (Circuit, Rational{Int}))

end # module
Loading

0 comments on commit 07777a5

Please sign in to comment.