Skip to content

Commit

Permalink
Load generic FFT fallbacks in FourierTransforms only optionally
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst committed Mar 16, 2020
1 parent 662b800 commit cacb5b0
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 106 deletions.
10 changes: 10 additions & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ function __init__()
@require IterativeSolvers="42fd0dbc-a981-5370-80f2-aaf504508153" begin
include("eigen/diag_lobpcg_itsolve.jl")
end

# Load generic FFT stuff once IntervalArithmetic or DoubleFloats are used
# The global variable GENERIC_FFT_LOADED makes sure that things are only
# included once.
@require IntervalArithmetic="d1acc4aa-44c8-5952-acd4-ba5d80a2a253" begin
!isdefined(DFTK, :GENERIC_FFT_LOADED) && include("fft_generic.jl")
end
@require DoubleFloats="497a8b3b-efae-58df-a0af-a86822472b78" begin
!isdefined(DFTK, :GENERIC_FFT_LOADED) && include("fft_generic.jl")
end
end

end # module DFTK
7 changes: 2 additions & 5 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,8 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number,
fft_size = Tuple{Int, Int, Int}(fft_size)

# TODO generic FFT is kind of broken for some fft sizes
# ... temporary workaround, see more details in fft.jl
if !(T in [Float32, Float64]) && !all(is_fft_size_ok_for_generic.(fft_size))
fft_size = next_working_fft_size_for_generic.(fft_size)
@info "Changing fft size to $fft_size (smallest working size for generic FFTs)"
end
# ... temporary workaround, see more details in fft_generic.jl
fft_size = next_working_fft_size.(T, fft_size)
ipFFT, opFFT = build_fft_plans(T, fft_size)

# The FFT interface specifies that fft has no normalization, and
Expand Down
120 changes: 19 additions & 101 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import FFTW
import Primes
include("FourierTransforms.jl/FourierTransforms.jl")

@doc raw"""
Determine the minimal grid size for the cubic basis set to be able to
Expand Down Expand Up @@ -40,109 +38,29 @@ end


"""
Plan a FFT of type `T` and size `fft_size`, spending some time on finding an optimal algorithm.
Both an inplace and an out-of-place FFT plan are returned.
Plan a FFT of type `T` and size `fft_size`, spending some time on finding an
optimal algorithm. Both an inplace and an out-of-place FFT plan are returned.
"""
function build_fft_plans(T, fft_size)
tmp = Array{Complex{T}}(undef, fft_size...)
if T == Float64
ipFFT = FFTW.plan_fft!(tmp, flags=FFTW.MEASURE)
opFFT = FFTW.plan_fft(tmp, flags=FFTW.MEASURE)
return ipFFT, opFFT
elseif T == Float32
# TODO For Float32 there are issues with aligned FFTW plans.
# Using unaligned FFTW plans is discouraged, but we do it anyways
# here as a quick fix. We should reconsider this in favour of using
# a parallel wisdom anyways in the future.
ipFFT = FFTW.plan_fft!(tmp, flags=FFTW.MEASURE | FFTW.UNALIGNED)
opFFT = FFTW.plan_fft(tmp, flags=FFTW.MEASURE | FFTW.UNALIGNED)
return ipFFT, opFFT
end

# Fall back to FourierTransforms
# Note: FourierTransforms has no support for in-place FFTs at the moment
# ... also it's extension to multi-dimensional arrays is broken and
# the algo only works for some cases
@assert all(is_fft_size_ok_for_generic.(fft_size))

# opFFT = FourierTransforms.plan_fft(tmp) # TODO When multidim works
opFFT = generic_plan_fft(tmp) # Fallback for now
# TODO Can be cut once FourierTransforms supports AbstractFFTs properly
ipFFT = DummyInplace{typeof(opFFT)}(opFFT)

function build_fft_plans(T::Type{Float32}, fft_size)
    # TODO Aligned FFTW plans are problematic for Float32, therefore unaligned
    #      plans are requested here, even though FFTW generally discourages them.
    planner_flags = FFTW.MEASURE | FFTW.UNALIGNED

    # Scratch array handed to the planner for both plans
    scratch = Array{ComplexF32}(undef, fft_size...)
    inplace = FFTW.plan_fft!(scratch, flags=planner_flags)
    outofplace = FFTW.plan_fft(scratch, flags=planner_flags)
    return inplace, outofplace
end

# Utility functions to set up FFTs for DFTK. Most functions in here
# are needed to correct for the fact that FourierTransforms is not
# yet fully compliant with the AbstractFFTs interface and still has
# various bugs that we work around.

function is_fft_size_ok_for_generic(size::Integer)
    # TODO FourierTransforms has a bug, which is only triggered for some
    #      factorisations of the grid size, see
    #      https://github.com/JuliaComputing/FourierTransforms.jl/issues/10
    #      Sizes with at most one prime factor different from two (and no
    #      large primes) were expected to be fine, but in practice the generic
    #      FFT only seems to work reliably for powers of two.
    #
    # Accept the size only if every prime factor equals two.
    for (prime, _) in Primes.factor(size)
        prime == 2 || return false
    end
    return true
end

# Return the first value, counting upwards from `size` in steps of one, which
# is accepted by is_fft_size_ok_for_generic.
function next_working_fft_size_for_generic(size)
    candidate = size
    while true
        is_fft_size_ok_for_generic(candidate) && return candidate
        candidate += 1
    end
end

# FFT plan for 3D arrays assembled from one 1D plan per array dimension plus
# a scalar prefactor. generic_apply destructures `subplans` into exactly three
# plans and multiplies the transformed array by `factor`.
struct GenericPlan{T}
# The three 1D plans (for the first, second and third array dimension).
# NOTE(review): the field carries no type annotation, so it is stored as `Any`.
subplans
# Scalar prefactor applied to the result of the transform
factor::T
end

# Apply the three 1D sub-plans of `p` to the array `X`, one dimension after the
# other (third, then second, then first), and return a new array holding the
# result scaled by `p.factor`.
function generic_apply(p::GenericPlan, X::AbstractArray)
    plan_x, plan_y, plan_z = p.subplans
    out = similar(X)

    # Third dimension: read the lines from X, store them in out
    @views for i in 1:size(X, 1), j in 1:size(X, 2)
        out[i, j, :] .= plan_z * X[i, j, :]
    end

    # Second dimension: transform the lines already stored in out
    @views for i in 1:size(X, 1), k in 1:size(X, 3)
        out[i, :, k] .= plan_y * out[i, :, k]
    end

    # First dimension: again working on out
    @views for j in 1:size(X, 2), k in 1:size(X, 3)
        out[:, j, k] .= plan_x * out[:, j, k]
    end

    return p.factor .* out
end

# mul! / ldiv! for GenericPlan: evaluate out-of-place and copy the result into Y
function LinearAlgebra.mul!(Y, p::GenericPlan, X)
    Y .= p * X
end
function LinearAlgebra.ldiv!(Y, p::GenericPlan, X)
    Y .= p \ X
end

import Base: *, \, inv, length

# Product of the lengths of the individual sub-plans
function length(p::GenericPlan)
    return prod(length, p.subplans)
end

# Applying the plan to an array
function *(p::GenericPlan, X::AbstractArray)
    return generic_apply(p, X)
end

# Scaling a plan by a number only changes its prefactor
function *(p::GenericPlan{T}, fac::Number) where T
    scaled_factor = p.factor * T(fac)
    return GenericPlan{T}(p.subplans, scaled_factor)
end
function *(fac::Number, p::GenericPlan{T}) where T
    return p * fac
end

# Left division applies the inverse plan
function \(p::GenericPlan, X)
    return inv(p) * X
end

# The inverse plan consists of the inverted sub-plans and the reciprocal prefactor
function inv(p::GenericPlan{T}) where T
    inverse_subplans = inv.(p.subplans)
    return GenericPlan{T}(inverse_subplans, 1 / p.factor)
end

function generic_plan_fft(data::AbstractArray{T, 3}) where T
GenericPlan{T}([FourierTransforms.plan_fft(data[:, 1, 1]),
FourierTransforms.plan_fft(data[1, :, 1]),
FourierTransforms.plan_fft(data[1, 1, :])], T(1))
function build_fft_plans(T::Type{Float64}, fft_size)
    # Scratch array handed to the planner for both plans
    scratch = Array{ComplexF64}(undef, fft_size...)
    planner_flags = FFTW.MEASURE
    inplace = FFTW.plan_fft!(scratch, flags=planner_flags)
    outofplace = FFTW.plan_fft(scratch, flags=planner_flags)
    return inplace, outofplace
end


# Thin wrapper which makes an out-of-place FFT plan usable in places where an
# in-place plan is expected. Required for generic FFT implementations, which
# do not offer in-place plans.
struct DummyInplace{opFFT}
    fft::opFFT
end

# mul! / ldiv!: run the wrapped plan into a scratch array, then copy into Y
function LinearAlgebra.mul!(Y, p::DummyInplace, X)
    scratch = similar(X)
    Y .= mul!(scratch, p.fft, X)
end
function LinearAlgebra.ldiv!(Y, p::DummyInplace, X)
    scratch = similar(X)
    Y .= ldiv!(scratch, p.fft, X)
end

import Base: *, \, length

# Everything else is forwarded to the wrapped plan
function *(p::DummyInplace, X)
    return p.fft * X
end
function \(p::DummyInplace, X)
    return p.fft \ X
end
function length(p::DummyInplace)
    return length(p.fft)
end
# TODO The generic FFT implementation in FourierTransforms is broken for some
#      grid sizes (see fft_generic.jl for details), such that the grid size
#      needs to be adjusted there. For the floating-point types natively
#      supported by FFTW no adjustment is needed: these methods are no-ops,
#      which return `size` unchanged.
function next_working_fft_size(::Type{Float32}, size)
    return size
end
function next_working_fft_size(::Type{Float64}, size)
    return size
end
93 changes: 93 additions & 0 deletions src/fft_generic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
include("FourierTransforms.jl/FourierTransforms.jl")

# This is needed to flag that the fft_generic.jl file has already been loaded
const GENERIC_FFT_LOADED = true


# Utility functions to set up FFTs for DFTK. Most functions in here
# are needed to correct for the fact that FourierTransforms is not
# yet fully compliant with the AbstractFFTs interface and still has
# various bugs that we work around.

# Return the smallest FFT grid size >= `size` which is safe to use with the
# generic FFT fallback, i.e. the next power of two (`size` itself if it already
# is a power of two). An `@info` message is emitted whenever the size had to be
# changed. The first argument is not used in the body; it only takes part in
# method dispatch.
function next_working_fft_size(::Any, size)
    # TODO FourierTransforms has a bug, which is triggered
    #      only in some factorisations, see
    #      https://github.com/JuliaComputing/FourierTransforms.jl/issues/10
    #      To be safe we fall back to powers of two

    adjusted = nextpow(2, size)
    if adjusted != size
        # Report the size which is actually used from now on. (This message
        # used to interpolate `fft_size`, which does not exist in this scope.)
        @info "Changing fft size to $adjusted (smallest working size for generic FFTs)"
    end
    return adjusted
end

# Generic fallback for all floating-point types without a specialised method
# (the Float32 and Float64 specialisations are in fft.jl)
function build_fft_plans(T, fft_size)
    buffer = Array{Complex{T}}(undef, fft_size...)

    # Note: FourierTransforms has no support for in-place FFTs at the moment.
    #       Moreover its extension to multi-dimensional arrays is broken and
    #       the algorithm only works for some cases, hence only powers of two
    #       are accepted as grid sizes here.
    @assert all(ispow2, fft_size)

    # TODO Use FourierTransforms.plan_fft(buffer) directly once multidim works
    outofplace = generic_plan_fft(buffer)

    # TODO Can be cut once FourierTransforms supports AbstractFFTs properly
    inplace = DummyInplace{typeof(outofplace)}(outofplace)

    return inplace, outofplace
end



# FFT plan for 3D arrays assembled from one 1D plan per array dimension plus
# a scalar prefactor. generic_apply destructures `subplans` into exactly three
# plans and multiplies the transformed array by `factor`.
struct GenericPlan{T}
# The three 1D plans (for the first, second and third array dimension).
# NOTE(review): the field carries no type annotation, so it is stored as `Any`.
subplans
# Scalar prefactor applied to the result of the transform
factor::T
end

# Apply the three 1D sub-plans of `p` to the array `X`, one dimension after the
# other (third, then second, then first), and return a new array holding the
# result scaled by `p.factor`.
function generic_apply(p::GenericPlan, X::AbstractArray)
    plan_x, plan_y, plan_z = p.subplans
    out = similar(X)

    # Third dimension: read the lines from X, store them in out
    @views for i in 1:size(X, 1), j in 1:size(X, 2)
        out[i, j, :] .= plan_z * X[i, j, :]
    end

    # Second dimension: transform the lines already stored in out
    @views for i in 1:size(X, 1), k in 1:size(X, 3)
        out[i, :, k] .= plan_y * out[i, :, k]
    end

    # First dimension: again working on out
    @views for j in 1:size(X, 2), k in 1:size(X, 3)
        out[:, j, k] .= plan_x * out[:, j, k]
    end

    return p.factor .* out
end

# mul! / ldiv! for GenericPlan: evaluate out-of-place and copy the result into Y
function LinearAlgebra.mul!(Y, p::GenericPlan, X)
    Y .= p * X
end
function LinearAlgebra.ldiv!(Y, p::GenericPlan, X)
    Y .= p \ X
end

import Base: *, \, inv, length

# Product of the lengths of the individual sub-plans
function length(p::GenericPlan)
    return prod(length, p.subplans)
end

# Applying the plan to an array
function *(p::GenericPlan, X::AbstractArray)
    return generic_apply(p, X)
end

# Scaling a plan by a number only changes its prefactor
function *(p::GenericPlan{T}, fac::Number) where T
    scaled_factor = p.factor * T(fac)
    return GenericPlan{T}(p.subplans, scaled_factor)
end
function *(fac::Number, p::GenericPlan{T}) where T
    return p * fac
end

# Left division applies the inverse plan
function \(p::GenericPlan, X)
    return inv(p) * X
end

# The inverse plan consists of the inverted sub-plans and the reciprocal prefactor
function inv(p::GenericPlan{T}) where T
    inverse_subplans = inv.(p.subplans)
    return GenericPlan{T}(inverse_subplans, 1 / p.factor)
end

# Build a GenericPlan for the 3D array `data`: one 1D FourierTransforms plan
# is created per dimension (each from the line through the first element of
# `data` along that dimension) and the prefactor is initialised to one.
function generic_plan_fft(data::AbstractArray{T, 3}) where T
    plan_x = FourierTransforms.plan_fft(data[:, 1, 1])
    plan_y = FourierTransforms.plan_fft(data[1, :, 1])
    plan_z = FourierTransforms.plan_fft(data[1, 1, :])
    return GenericPlan{T}([plan_x, plan_y, plan_z], T(1))
end


# Thin wrapper which makes an out-of-place FFT plan usable in places where an
# in-place plan is expected. Required for generic FFT implementations, which
# do not offer in-place plans.
struct DummyInplace{opFFT}
    fft::opFFT
end

# mul! / ldiv!: run the wrapped plan into a scratch array, then copy into Y
function LinearAlgebra.mul!(Y, p::DummyInplace, X)
    scratch = similar(X)
    Y .= mul!(scratch, p.fft, X)
end
function LinearAlgebra.ldiv!(Y, p::DummyInplace, X)
    scratch = similar(X)
    Y .= ldiv!(scratch, p.fft, X)
end

import Base: *, \, length

# Everything else is forwarded to the wrapped plan
function *(p::DummyInplace, X)
    return p.fft * X
end
function \(p::DummyInplace, X)
    return p.fft \ X
end
function length(p::DummyInplace)
    return length(p.fft)
end

0 comments on commit cacb5b0

Please sign in to comment.