Skip to content

Commit

Permalink
Implement SIMD computations (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
KeitaNakamura authored Feb 4, 2021
1 parent e4efc19 commit 7f1921c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.3.5"
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Expand Down
2 changes: 2 additions & 0 deletions src/Tensorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export ⋅, ×, dot, tr, det, norm, mean, I, eigen, eigvals, eigvecs
using StaticArrays
using Base: @pure, @_inline_meta, @_propagate_inbounds_meta
using ForwardDiff: Dual, value, partials
import SIMD

import Base: transpose, inv
import LinearAlgebra: dot, norm, tr, adjoint, det, cross, eigen, eigvals, eigvecs
Expand Down Expand Up @@ -66,6 +67,7 @@ include("Tensor.jl")
include("ops.jl")
include("voigt.jl")
include("ad.jl")
include("simd.jl")

const = otimes
const = double_contraction
Expand Down
63 changes: 63 additions & 0 deletions src/simd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
const SIMDTypes = Union{Float16, Float32, Float64}

# TODO: implement more efficient computations for symmetric case
@generated function contraction(x::Tensor{<: Any, T}, y::Tensor{<: Any, T}, ::Val{N}) where {T <: SIMDTypes, N}
S1 = Size(x)
S2 = Size(y)
S = contraction(S1, S2, Val(N))
s1 = map(i -> EinsumIndex(:(Tuple(x)), i), independent_indices(S1))
s2 = map(i -> EinsumIndex(:(Tuple(y)), i), independent_indices(S2))
J = prod(size(s2)[1:N])
I = length(s1) ÷ J
K = length(s2) ÷ J
s1′ = reshape(s1, I, J)
s2′ = reshape(s2, J, K)
# Create slices
# [a c
# b d] => [[a,b], [c,d]]
slices = map(axes(s1′, 2)) do j
:($(Symbol(:v, j)) = SIMD.Vec($(map(construct_expr, s1′[:,j])...)))
end
columns = map(axes(s2′, 2)) do j
coefs = map(axes(s2′, 1)) do i
construct_expr(s2′[i, j])
end
code = :(v1 * $(coefs[1]))
for i in 2:length(coefs)
code = :(muladd($(Symbol(:v, i)), $(coefs[i]), $code))
end
code
end
exps = map(indices(S)) do i
d, r = divrem(i-1, I) .+ 1
:(columns[$d][$r])
end
if length(S) == 0
TT = T
else
TT = tensortype(S){T}
end
quote
@_inline_meta
@inbounds begin
$(slices...)
columns = tuple($(columns...))
$TT($(exps...))
end
end
end

for op in (:+, :-)
@eval @inline function Base.$op(x::TT, y::TT) where {T <: SIMDTypes, TT <: Tensor{<: Any, T}}
TT(Tuple($op(SIMD.Vec(Tuple(x)), SIMD.Vec(Tuple(y)))))
end
end

for op in (:*, :/)
@eval @inline function Base.$op(x::TT, a::T) where {T <: SIMDTypes, TT <: Tensor{<: Any, T}}
TT(Tuple($op(SIMD.Vec(Tuple(x)), a)))
end
end
@inline function Base.:*(a::T, x::TT) where {T <: SIMDTypes, TT <: Tensor{<: Any, T}}
x * a
end

0 comments on commit 7f1921c

Please sign in to comment.