Skip to content

Commit

Permalink
Option to specify timeout for SCFs (JuliaMolSim#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst authored Jan 25, 2024
1 parent 5049ffb commit d2b7359
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 15 deletions.
3 changes: 2 additions & 1 deletion ext/DFTKJLD2Ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ DFTK.make_subdict!(jld::Union{JLD2.Group,JLD2.JLDFile}, name::AbstractString) =
function save_jld2(to_dict_function!, file::AbstractString, scfres::NamedTuple;
save_ψ=true, save_ρ=true, extra_data=Dict{String,Any}(), compress=false)
if mpi_master()
JLD2.jldopen(file, "w"; compress) do jld
JLD2.jldopen(file * ".new", "w"; compress) do jld
to_dict_function!(jld, scfres; save_ψ, save_ρ)
for (k, v) in pairs(extra_data)
jld[k] = v
Expand All @@ -19,6 +19,7 @@ function save_jld2(to_dict_function!, file::AbstractString, scfres::NamedTuple;
delete!(jld, "kgrid")
jld["kgrid"] = scfres.basis.kgrid # Save original kgrid datastructure
end
mv(file * ".new", file; force=true)
else
dummy = Dict{String,Any}()
to_dict_function!(dummy, scfres; save_ψ)
Expand Down
5 changes: 3 additions & 2 deletions ext/DFTKJSON3Ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ function save_json(todict_function, filename::AbstractString, scfres::NamedTuple
data[k] = v
end
if mpi_master()
open(filename, "w") do io
JSON3.pretty(io, data)
open(filename * ".new", "w") do io
JSON3.write(io, data)
end
mv(filename * ".new", filename; force=true)
end
MPI.Barrier(MPI.COMM_WORLD)
nothing
Expand Down
20 changes: 9 additions & 11 deletions src/scf/scf_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
# maxiter), where f(x) is the fixed-point map. It must return an
# object supporting res.sol and res.converged

# TODO max_iter could go to the solver generator function arguments

"""
Create a damped SCF solver updating the density as
`x = β * x_new + (1 - β) * x`
"""
function scf_damping_solver=0.2)
function fp_solver(f, x0, max_iter; tol=1e-6)
function fp_solver(f, x0, maxiter; tol=1e-6)
β = convert(eltype(x0), β)
converged = false
x = copy(x0)
for i = 1:max_iter
for i = 1:maxiter
x_new = f(x)

if norm(x_new - x) < tol
Expand All @@ -36,13 +34,13 @@ Create a simple anderson-accelerated SCF solver. `m` specifies the number
of steps to keep the history of.
"""
function scf_anderson_solver(m=10; kwargs...)
function anderson(f, x0, max_iter; tol=1e-6)
function anderson(f, x0, maxiter; tol=1e-6)
T = eltype(x0)
x = x0

converged = false
acceleration = AndersonAcceleration(; m, kwargs...)
for n = 1:max_iter
for n = 1:maxiter
residual = f(x) - x
converged = norm(residual) < tol
converged && break
Expand All @@ -57,7 +55,7 @@ CROP-accelerated root-finding iteration for `f`, starting from `x0` and keeping
a history of `m` steps. Optionally `warming` specifies the number of non-accelerated
steps to perform for warming up the history.
"""
function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
function CROP(f, x0, m::Int, maxiter::Int, tol::Real, warming=0)
# CROP iterates maintain xn and fn (/!\ fn != f(xn)).
# xtn+1 = xn + fn
# ftn+1 = f(xtn+1)
Expand All @@ -70,7 +68,7 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)

# Cheat support for multidimensional arrays
if length(size(x0)) != 1
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, max_iter, tol, warming)
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, maxiter, tol, warming)
return (; fixpoint=reshape(x, size(x0)...), converged=conv)
end
N = size(x0,1)
Expand All @@ -79,10 +77,10 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
fs = zeros(T, N, m+1) # newest to oldest
xs[:,1] = x0
fs[:,1] = f(x0) # Residual
errs = zeros(max_iter)
errs = zeros(maxiter)
err = Inf

for n = 1:max_iter
for n = 1:maxiter
xtnp1 = xs[:, 1] + fs[:, 1] # Richardson update
ftnp1 = f(xtnp1) # Residual
err = norm(ftnp1)
Expand Down Expand Up @@ -112,4 +110,4 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
end
(; fixpoint=xs[:, 1], converged=err < tol)
end
scf_CROP_solver(m=10) = (f, x0, max_iter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, max_iter, tol)
scf_CROP_solver(m=10) = (f, x0, maxiter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, maxiter, tol)
6 changes: 6 additions & 0 deletions src/scf/self_consistent_field.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
include("scf_callbacks.jl")
using Dates

"""
Transparently handle checkpointing by either returning kwargs for `self_consistent_field`,
Expand Down Expand Up @@ -109,6 +110,8 @@ Overview of parameters:
- `is_converged`: Convergence control callback. Typical objects passed here are
`ScfConvergenceDensity(tol)` (the default), `ScfConvergenceEnergy(tol)` or `ScfConvergenceForce(tol)`.
- `maxiter`: Maximal number of SCF iterations
- `maxtime`: Maximal time to run the SCF for. If this is reached without
convergence, the SCF stops.
- `mixing`: Mixing method, which determines the preconditioner ``P^{-1}`` in the above equation.
Typical mixings are [`LdosMixing`](@ref), [`KerkerMixing`](@ref), [`SimpleMixing`](@ref)
or [`DielectricMixing`](@ref). Default is `LdosMixing()`
Expand All @@ -129,6 +132,7 @@ Overview of parameters:
tol=1e-6,
is_converged=ScfConvergenceDensity(tol),
maxiter=100,
maxtime=Year(1),
mixing=LdosMixing(),
damping=0.8,
solver=scf_anderson_solver(),
Expand All @@ -152,6 +156,7 @@ Overview of parameters:
energies = nothing
ham = nothing
start_ns = time_ns()
end_time = Dates.now() + maxtime
info = (; n_iter=0, ρin=ρ) # Populate info with initial values
history_Etot = T[]
history_Δρ = T[]
Expand All @@ -161,6 +166,7 @@ Overview of parameters:
# TODO support other mixing types
function fixpoint_map(ρin)
converged && return ρin # No more iterations if convergence flagged
MPI.bcast(Dates.now() end_time, MPI.COMM_WORLD) && return ρin
n_iter += 1

# Note that ρin is not the density of ψ, and the eigenvalues
Expand Down
2 changes: 1 addition & 1 deletion test/scf_compare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

# Run other SCFs with SAD guess
ρ0 = guess_density(basis)
for solver in (scf_anderson_solver(), scf_damping_solver(1.0), scf_CROP_solver())
for solver in (scf_anderson_solver(), scf_damping_solver(), scf_CROP_solver())
@testset "Testing $solver" begin
ρ_alg = self_consistent_field(basis; ρ=ρ0, solver, tol).ρ
@test maximum(abs, ρ_alg - ρ_def) < 50tol
Expand Down

0 comments on commit d2b7359

Please sign in to comment.