Skip to content

Commit

Permalink
started thinking about new modes interface
Browse files Browse the repository at this point in the history
  • Loading branch information
marcsgil committed Jul 29, 2024
1 parent 36bfd75 commit 437506f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 14 deletions.
48 changes: 48 additions & 0 deletions src/initial_profiles copy.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function ψ!(dest, mode_kernel!, x, y; kwargs...)
backend = get_backend(dest)
kernel! = mode_kernel!(backend)
ψ!(dest, kernel!, x, y; kwargs..., ndrange=size(dest))
end

function ψ!(dest, mode_kernel!, normalization_func, x, y; kwargs...)
ψ!(dest, mode_kernel!, x, y; kwargs...)
N = normalization_func(kwargs...)
dest ./= N
end

function ψ(::Type{T}, args...; kwargs...) where {T}
dest = Matrix{T}(undef, length(x), length(y))
ψ!(dest, args...; kwargs...)
end

function ψ(args...; kwargs...)

end

#Laguerre-Gaussian modes

function _lg(x, y; p::Integer, l::Integer, γ=one(eltype(x)))
X = x / γ
Y = y / γ
r2 = X^2 + Y^2
L = abs(l)
exp(-r2 / 2) * (X + im * sign(l) * Y)^L * laguerre(r2, p, L)
end

"""
normalization_lg(p,l,γ=1)
Compute the normalization constant for the Laguerre-Gaussian modes.
"""
function normalization_lg(p, l, γ=1)
convert(float(eltype(γ)), inv(prod(p+1:p+abs(l)) * π) / γ)
end

@kernel function lg_kernel!(dest, x, y, p, l, γ)
j, k = @index(Global, NTuple)
dest[j, k] = _lg(x[j], y[k]; p, l, γ)
end

function lg(x, y)

end
34 changes: 20 additions & 14 deletions src/initial_profiles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function assert_hermite_indices(idxs...)
end
end

float_type(args...) = float(typeof(sum(first, args)))
float_type(args...) = promote_type((eltype(arg) for arg args)...) |> float

"""
normalization_hg(m,n,γ=1)
Expand Down Expand Up @@ -78,7 +78,7 @@ See also [`hg`](@ref), [`diagonal_hg`](@ref), [`lg`](@ref).
"""
function rotated_hg(x::Real, y::Real; θ, m::Integer=0, n::Integer=0, w=one(eltype(x)), include_normalization=true)
assert_hermite_indices(m, n)
T = float_type(x, y, w)
T = float_type(x, y, w, θ)
γ = convert(T, w / 2)
s, c = sincos(θ)

Expand All @@ -93,15 +93,14 @@ end

function rotated_hg(x, y; θ, m::Integer=0, n::Integer=0, w=one(eltype(x)), include_normalization=true)
assert_hermite_indices(m, n)
T = float_type(x, y, w)
T = float_type(x, y, w, θ)
γ = convert(T, w / 2)
s, c = sincos(θ)

@tullio result[j, k] := _hg(x[j], y[k], s, c; m, n, γ)
"""result = similar(x, length(x), length(y))
result = similar(x, length(x), length(y))
backend = get_backend(result)
kernel! = hg_kernel!(backend, 256)
kernel!(result, x, y, s, c, m, n, γ; ndrange=size(result))"""
kernel!(result, x, y, s, c, m, n, γ; ndrange=size(result))

if include_normalization
N = normalization_hg(m, n, γ)
Expand All @@ -115,7 +114,7 @@ function rotated_hg(x::Real, y::Real, z::Real;
θ, m::Integer=0, n::Integer=0, w=one(eltype(x)), k=one(eltype(x)), include_normalization=true)
assert_hermite_indices(m, n)

T = float(typeof(sum((x, y, z, w, k))))
T = float_type(x, y, z, w, k, θ)
γ = convert(T, w / 2)
k = convert(T, k)
s, c = sincos(θ)
Expand All @@ -133,7 +132,7 @@ function rotated_hg(x, y, z::Real;
θ, m::Integer=0, n::Integer=0, w=one(eltype(x)), k=one(eltype(x)), include_normalization=true)
assert_hermite_indices(m, n)

T = float_type(x, y, z, w, k)
T = float_type(x, y, z, w, k, θ)
γ = convert(T, w / 2)
k = convert(T, k)
s, c = sincos(θ)
Expand All @@ -152,7 +151,7 @@ function rotated_hg(x, y, z;
θ, m::Integer=0, n::Integer=0, w=one(eltype(x)), k=one(eltype(x)), include_normalization=true)
assert_hermite_indices(m, n)

T = float_type(x, y, z, w, k)
T = float_type(x, y, z, w, k, θ)
γ = convert(T, w / 2)
k = convert(T, k)
s, c = sincos(θ)
Expand Down Expand Up @@ -214,8 +213,8 @@ true
See also [`rotated_hg`](@ref), [`diagonal_hg`](@ref), [`lg`](@ref).
"""
hg(x, y; kwargs...) = rotated_hg(x, y; θ=zero(first(x)), kwargs...)
hg(x, y, z; kwargs...) = rotated_hg(x, y, z; θ=zero(first(x)), kwargs...)
hg(x, y; kwargs...) = rotated_hg(x, y; θ=zero(eltype(x)), kwargs...)
hg(x, y, z; kwargs...) = rotated_hg(x, y, z; θ=zero(eltype(x)), kwargs...)

"""
diagonal_hg(x, y; m::Integer=0, n::Integer=0, w=one(eltype(x)))
Expand Down Expand Up @@ -258,8 +257,8 @@ true
See also [`rotated_hg`](@ref), [`hg`](@ref), [`lg`](@ref).
"""
diagonal_hg(x, y; kwargs...) = rotated_hg(x, y; θ=oftype(float(first(x)), π / 4), kwargs...)
diagonal_hg(x, y, z; kwargs...) = rotated_hg(x, y, z; θ=oftype(float(first(x)), π / 4), kwargs...)
diagonal_hg(x, y; kwargs...) = rotated_hg(x, y; θ=convert(eltype(x), π / 4), kwargs...)
diagonal_hg(x, y, z; kwargs...) = rotated_hg(x, y, z; θ=convert(eltype(x), π / 4), kwargs...)

function _lg(x, y; p::Integer, l::Integer, γ=one(eltype(x)))
X = x / γ
Expand All @@ -286,6 +285,10 @@ function normalization_lg(p, l, γ=1)
convert(float(eltype(γ)), inv(prod(p+1:p+abs(l)) * π) / γ)
end

@kernel function lg_kernel!(dest, x, y, p, l, γ)
j, k = @index(Global, NTuple)
dest[j, k] = _lg(x[j], y[k]; p, l, γ)
end
"""
lg(x, y; p::Integer=0, l::Integer=0, w=one(eltype(x)))
lg(x, y, z; p::Integer=0, l::Integer=0, w=one(eltype(x)), k=one(eltype(x)))
Expand Down Expand Up @@ -346,7 +349,10 @@ function lg(x, y; p::Integer=0, l::Integer=0, w=one(eltype(x)), include_normaliz
T = float_type(x, y, w)
γ = convert(T, w / 2)

@tullio result[j, k] := _lg(x[j], y[k]; p, l, γ)
result = similar(x, complex(T), length(x), length(y))
backend = get_backend(result)
kernel! = lg_kernel!(backend, 256)
kernel!(result, x, y, p, l, γ; ndrange=size(result))

if include_normalization
N = normalization_lg(p, l, γ)
Expand Down

0 comments on commit 437506f

Please sign in to comment.