Skip to content

Commit

Permalink
Fix performance gotcha (JuliaMolSim#665)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael F. Herbst <[email protected]>
  • Loading branch information
antoine-levitt and mfherbst authored Jun 21, 2022
1 parent 4e8b4ba commit 7f848ae
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using TimerOutputs
using spglib_jll
using Unitful
using UnitfulAtomic
using ForwardDiff

export Vec3
export Mat3
Expand Down
15 changes: 9 additions & 6 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# A physical specification of a model.
# Contains the geometry information, but no discretization parameters.
# The exact model used is defined by the list of terms.
struct Model{T <: Real}
struct Model{T <: Real, VT <: Real}
# T is the default type to express data, VT the corresponding bare value type (i.e. not dual)

# Human-readable name for the model (like LDA, PBE, ...)
model_name::String

Expand Down Expand Up @@ -55,7 +57,7 @@ struct Model{T <: Real}
term_types::Vector

# list of symmetries of the model
symmetries::Vector{SymOp}
symmetries::Vector{SymOp{VT}}
end

_is_well_conditioned(A; tol=1e5) = (cond(A) <= tol)
Expand Down Expand Up @@ -156,10 +158,11 @@ function Model(lattice::AbstractMatrix{T}, atoms=Element[], positions=Vec3{T}[];
end
@assert !isempty(symmetries) # Identity has to be always present.

Model{T}(model_name, lattice, recip_lattice, n_dim, inv_lattice, inv_recip_lattice,
unit_cell_volume, recip_cell_volume,
n_electrons, spin_polarization, n_spin, T(temperature), smearing,
atoms, positions, atom_groups, terms, symmetries)
VT = ForwardDiff.valtype(T)
Model{T,VT}(model_name, lattice, recip_lattice, n_dim, inv_lattice, inv_recip_lattice,
unit_cell_volume, recip_cell_volume,
n_electrons, spin_polarization, n_spin, T(temperature), smearing,
atoms, positions, atom_groups, terms, symmetries)
end
function Model(lattice::AbstractMatrix{<: Integer}, args...; kwargs...)
Model(Float64.(lattice), args...; kwargs...)
Expand Down
11 changes: 6 additions & 5 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ Normalization conventions:
`G_to_r` and `r_to_G` convert between these representations.
"""
struct PlaneWaveBasis{T} <: AbstractBasis{T}
model::Model{T}
struct PlaneWaveBasis{T, VT} <: AbstractBasis{T} where {VT <: Real}
# T is the default type to express data, VT the corresponding bare value type (i.e. not dual)
model::Model{T, VT}

## Global grid information
# fft_size defines both the G basis on which densities and
Expand Down Expand Up @@ -90,7 +91,7 @@ struct PlaneWaveBasis{T} <: AbstractBasis{T}

## Symmetry operations that leave the discretized model (k and r grids) invariant.
# Subset of model.symmetries.
symmetries::Vector{SymOp}
symmetries::Vector{SymOp{VT}}
# Whether the symmetry operations leave the rgrid invariant
# If this is true, the symmetries are a property of the complete discretized model.
# Therefore, all quantities should be symmetric to machine precision
Expand Down Expand Up @@ -176,7 +177,7 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
# Manual kpoint set based on kcoords/kweights
@assert length(kcoords) == length(kweights)
all_kcoords = unfold_kcoords(kcoords, symmetries)
symmetries = symmetries_preserving_kgrid(symmetries, all_kcoords)
symmetries = symmetries_preserving_kgrid(symmetries, all_kcoords)
end

# Init MPI, and store MPI-global values for reference
Expand Down Expand Up @@ -241,7 +242,7 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,

dvol = model.unit_cell_volume ./ prod(fft_size)
terms = Vector{Any}(undef, length(model.term_types)) # Dummy terms array, filled below
basis = PlaneWaveBasis{T}(
basis = PlaneWaveBasis{T,ForwardDiff.valtype(T)}(
model, fft_size, dvol,
Ecut, variational,
opFFT, ipFFT, opBFFT, ipBFFT,
Expand Down
3 changes: 1 addition & 2 deletions src/Smearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
# "Numerical quadrature in the brillouin zone for periodic schrodinger operators"
# See also https://www.vasp.at/vasp-workshop/k-points.pdf
module Smearing

using SpecialFunctions: erf, erfc, factorial
import ForwardDiff
using SpecialFunctions: erf, erfc, factorial

abstract type SmearingFunction end

Expand Down
15 changes: 9 additions & 6 deletions src/SymOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ struct SymOp{T <: Real}
# (Uu)(G) = e^{-i G τ} u(S^-1 G) in reciprocal space
S::Mat3{Int}
τ::Vec3{T}
end
function SymOp(W, w::AbstractVector{T}) where {T}
w = mod.(w, 1)
S = W'
τ = -W \ w
SymOp{T}(W, w, S, τ)
end

function SymOp(W, w)
w = mod.(w, 1)
S = W'
τ = -W \w
new{eltype(τ)}(W, w, S, τ)
end
function Base.convert(::Type{SymOp{T}}, S::SymOp{U}) where {U <: Real, T <: Real}
SymOp{T}(S.W, T.(S.w), S.S, T.(S.τ))
end

Base.:(==)(op1::SymOp, op2::SymOp) = op1.W == op2.W && op1.w == op2.w
Expand Down
2 changes: 0 additions & 2 deletions src/densities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using ForwardDiff

# Densities (and potentials) are represented by arrays
# ρ[ix,iy,iz,iσ] in real space, where iσ ∈ [1:n_spin_components]

Expand Down
8 changes: 4 additions & 4 deletions src/external/jld2io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ load_scfres(file::AbstractString) = JLD2.jldopen(load_scfres, file, "r")
# Custom serialisations
#
struct PlaneWaveBasisSerialisation{T <: Real}
model::Model{T}
model::Model{T,T}
Ecut::T
variational::Bool
kcoords::Vector{Vec3{T}}
Expand All @@ -85,9 +85,9 @@ struct PlaneWaveBasisSerialisation{T <: Real}
symmetries_respect_rgrid::Bool
fft_size::Tuple{Int, Int, Int}
end
JLD2.writeas(::Type{PlaneWaveBasis{T}}) where {T} = PlaneWaveBasisSerialisation{T}
JLD2.writeas(::Type{PlaneWaveBasis{T,T}}) where {T} = PlaneWaveBasisSerialisation{T}

function Base.convert(::Type{PlaneWaveBasisSerialisation{T}}, basis::PlaneWaveBasis{T}) where {T}
function Base.convert(::Type{PlaneWaveBasisSerialisation{T}}, basis::PlaneWaveBasis{T,T}) where {T}
PlaneWaveBasisSerialisation{T}(
basis.model,
basis.Ecut,
Expand All @@ -101,7 +101,7 @@ function Base.convert(::Type{PlaneWaveBasisSerialisation{T}}, basis::PlaneWaveBa
)
end

function Base.convert(::Type{PlaneWaveBasis{T}}, serial::PlaneWaveBasisSerialisation{T}) where {T}
function Base.convert(::Type{PlaneWaveBasis{T,T}}, serial::PlaneWaveBasisSerialisation{T}) where {T}
PlaneWaveBasis(serial.model, serial.Ecut, serial.kcoords,
serial.kweights; serial.fft_size,
serial.kgrid, serial.kshift, serial.symmetries_respect_rgrid,
Expand Down
1 change: 0 additions & 1 deletion src/postprocess/stresses.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using ForwardDiff
"""
Compute the stresses (= 1/Vol dE/d(M*lattice), taken at M=I) of an obtained SCF solution.
"""
Expand Down
1 change: 0 additions & 1 deletion src/terms/local_nonlinearity.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using ForwardDiff
"""
Local nonlinearity, with energy ∫f(ρ) where ρ is the density
"""
Expand Down

0 comments on commit 7f848ae

Please sign in to comment.