Skip to content

Commit

Permalink
extend forwarddiff svec norm for partials (JuliaMolSim#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasschmitz authored Jul 17, 2021
1 parent c6da218 commit 0186c27
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,9 @@ end
# solution: follow ChainRules custom frule for norm
# https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/LinearAlgebra/norm.jl#L5
# TODO delete, once forward diff AD tools use ChainRules natively
function LinearAlgebra.norm(x::SVector{S,<:ForwardDiff.Dual}) where {S}
T = ForwardDiff.tagtype(eltype(x))
dx = ForwardDiff.partials.(x)
y = norm(ForwardDiff.value.(x))
dy = real(dot(ForwardDiff.value.(x), dx)) * pinv(y)
ForwardDiff.Dual{T}(y, dy)
function LinearAlgebra.norm(x::SVector{S,<:ForwardDiff.Dual{Tg,T,N}}) where {S,Tg,T,N}
x_value = ForwardDiff.value.(x)
y = norm(x_value)
dy = ntuple(j->real(dot(x_value, ForwardDiff.partials.(x,j))) * pinv(y), N)
ForwardDiff.Dual{Tg}(y, dy)
end

0 comments on commit 0186c27

Please sign in to comment.