Skip to content

Commit

Permalink
Add some missing fallbacks for interval arithmetic (JuliaMolSim#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst authored Jul 23, 2021
1 parent d26aaea commit a3ca5ed
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function Model(lattice::AbstractMatrix{T};
norm(lattice[:, i]) == norm(lattice[i, :]) == 0 || error(
"For 1D and 2D systems, the non-empty dimensions must come first")
end
_check_well_conditioned(lattice[1:n_dim, 1:n_dim]) || @warn (
_is_well_conditioned(lattice[1:n_dim, 1:n_dim]) || @warn (
"Your lattice is badly conditioned, the computation is likely to fail.")

# Compute reciprocal lattice and volumes.
Expand Down Expand Up @@ -214,4 +214,4 @@ function spin_components(spin_polarization::Symbol)
end
spin_components(model::Model) = spin_components(model.spin_polarization)

_check_well_conditioned(A; tol=1e5) = (cond(A) <= tol)
_is_well_conditioned(A; tol=1e5) = (cond(A) <= tol)
4 changes: 2 additions & 2 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ function spglib_get_symmetry(lattice::Matrix{<:ForwardDiff.Dual}, atoms, magneti
spglib_get_symmetry(ForwardDiff.value.(lattice), atoms, magnetic_moments; kwargs...)
end

function _check_well_conditioned(A::AbstractArray{<:ForwardDiff.Dual}; kwargs...)
_check_well_conditioned(ForwardDiff.value.(A); kwargs...)
function _is_well_conditioned(A::AbstractArray{<:ForwardDiff.Dual}; kwargs...)
_is_well_conditioned(ForwardDiff.value.(A); kwargs...)
end


Expand Down
14 changes: 13 additions & 1 deletion src/workarounds/intervals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import IntervalArithmetic: Interval, mid
# should be done e.g. by changing the rounding mode ...
erfc(i::Interval) = Interval(prevfloat(erfc(i.lo)), nextfloat(erfc(i.hi)))

function compute_fft_size(lattice::AbstractMatrix{T}, Ecut; kwargs...) where T <: Interval
function compute_fft_size(lattice::AbstractMatrix{<:Interval}, Ecut; kwargs...)
# This is done to avoid a call like ceil(Int, ::Interval)
# in the above implementation of compute_fft_size,
# where it is in general cases not clear, what to do.
Expand All @@ -17,6 +17,18 @@ function compute_fft_size(lattice::AbstractMatrix{T}, Ecut; kwargs...) where T <
compute_fft_size(mid.(lattice), Ecut; kwargs...)
end

function _is_well_conditioned(A::AbstractArray{<:Interval}; kwargs...)
# This check is used during the lattice setup, where it frequently fails with intervals
# (because doing an SVD with intervals leads to a large overestimation of the rounding error)
_is_well_conditioned(mid.(A); kwargs...)
end

function symmetry_operations(lattice::AbstractMatrix{<:Interval}, atoms, magnetic_moments=[];
tol_symmetry=max(1e-5, maximum(radius, lattice)))
@assert tol_symmetry < 1e-2
symmetry_operations(mid.(lattice), atoms, magnetic_moments; tol_symmetry)
end

function local_potential_fourier(el::ElementCohenBergstresser, q::T) where {T <: Interval}
lor = round(q.lo, digits=5)
hir = round(q.hi, digits=5)
Expand Down
3 changes: 2 additions & 1 deletion test/interval_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function discretized_hamiltonian(T, testcase)
spec = ElementPsp(testcase.atnum, psp=load_psp(testcase.psp))
atoms = [spec => testcase.positions]
# disable symmetry for interval
model = model_DFT(Array{T}(testcase.lattice), atoms, [:lda_x, :lda_c_vwn], symmetries=false)
model = model_DFT(Array{T}(testcase.lattice), atoms, [:lda_x, :lda_c_vwn])

# For interval arithmetic to give useful numbers,
# the fft_size should be a power of 2
Expand All @@ -25,6 +25,7 @@ end
T = Float64
ham = discretized_hamiltonian(T, silicon)
hamInt = discretized_hamiltonian(Interval{T}, silicon)
@test length(ham.basis.model.symmetries) == length(hamInt.basis.model.symmetries)

hamk = ham.blocks[1]
hamIntk = hamInt.blocks[1]
Expand Down

0 comments on commit a3ca5ed

Please sign in to comment.