Skip to content

Commit

Permalink
[NDTensors] Introduce NestedPermutedDimsArrays submodule (#1589)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 15, 2024
1 parent dbec36b commit 3594216
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 0 deletions.
1 change: 1 addition & 0 deletions NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ for lib in [
:GradedAxes,
:SymmetrySectors,
:TensorAlgebra,
:NestedPermutedDimsArrays,
:SparseArrayInterface,
:SparseArrayDOKs,
:DiagonalArrays,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
# Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose`
# (though those are fully recursive).
module NestedPermutedDimsArrays

import Base: permutedims, permutedims!
export NestedPermutedDimsArray

# Some day we will want storage-order-aware iteration, so put perm in the parameters
struct NestedPermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N}
parent::AA

function NestedPermutedDimsArray{T,N,perm,iperm,AA}(
data::AA
) where {T,N,perm,iperm,AA<:AbstractArray}
(isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) ||
error("perm and iperm must both be NTuple{$N,Int}")
isperm(perm) ||
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
all(d -> iperm[perm[d]] == d, 1:N) ||
throw(ArgumentError(string(perm, " and ", iperm, " must be inverses")))
return new(data)
end
end

"""
NestedPermutedDimsArray(A, perm) -> B
Given an AbstractArray `A`, create a view `B` such that the
dimensions appear to be permuted. Similar to `permutedims`, except
that no copying occurs (`B` shares storage with `A`).
See also [`permutedims`](@ref), [`invperm`](@ref).
# Examples
```jldoctest
julia> A = rand(3,5,4);
julia> B = NestedPermutedDimsArray(A, (3,1,2));
julia> size(B)
(4, 3, 5)
julia> B[3,1,2] == A[1,2,3]
true
```
"""
Base.@constprop :aggressive function NestedPermutedDimsArray(
data::AbstractArray{T,N}, perm
) where {T,N}
length(perm) == N ||
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
iperm = invperm(perm)
return NestedPermutedDimsArray{
PermutedDimsArray{eltype(T),N,(perm...,),(iperm...,),T},
N,
(perm...,),
(iperm...,),
typeof(data),
}(
data
)
end

Base.parent(A::NestedPermutedDimsArray) = A.parent
function Base.size(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
return genperm(size(parent(A)), perm)
end
function Base.axes(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
return genperm(axes(parent(A)), perm)
end
Base.has_offset_axes(A::NestedPermutedDimsArray) = Base.has_offset_axes(A.parent)
function Base.similar(A::NestedPermutedDimsArray, T::Type, dims::Base.Dims)
return similar(parent(A), T, dims)
end
function Base.cconvert(::Type{Ptr{T}}, A::NestedPermutedDimsArray{T}) where {T}
return Base.cconvert(Ptr{T}, parent(A))
end

# It's OK to return a pointer to the first element, and indeed quite
# useful for wrapping C routines that require a different storage
# order than used by Julia. But for an array with unconventional
# storage order, a linear offset is ambiguous---is it a memory offset
# or a linear index?
function Base.pointer(A::NestedPermutedDimsArray, i::Integer)
throw(
ArgumentError("pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray")
)
end

function Base.strides(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
s = strides(parent(A))
return ntuple(d -> s[perm[d]], Val(N))
end
function Base.elsize(::Type{<:NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}) where {P}
return Base.elsize(P)
end

@inline function Base.getindex(
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(A, I...)
@inbounds val = PermutedDimsArray(getindex(A.parent, genperm(I, iperm)...), perm)
return val
end
@inline function Base.setindex!(
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(A, I...)
@inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...)
return val
end

function Base.isassigned(
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(Bool, A, I...) || return false
@inbounds x = isassigned(A.parent, genperm(I, iperm)...)
return x
end

@inline genperm(I::NTuple{N,Any}, perm::Dims{N}) where {N} = ntuple(d -> I[perm[d]], Val(N))
@inline genperm(I, perm::AbstractVector{Int}) = genperm(I, (perm...,))

function Base.copyto!(
dest::NestedPermutedDimsArray{T,N}, src::AbstractArray{T,N}
) where {T,N}
checkbounds(dest, axes(src)...)
return _copy!(dest, src)
end
Base.copyto!(dest::NestedPermutedDimsArray, src::AbstractArray) = _copy!(dest, src)

function _copy!(P::NestedPermutedDimsArray{T,N,perm}, src) where {T,N,perm}
# If dest/src are "close to dense," then it pays to be cache-friendly.
# Determine the first permuted dimension
d = 0 # d+1 will hold the first permuted dimension of src
while d < ndims(src) && perm[d + 1] == d + 1
d += 1
end
if d == ndims(src)
copyto!(parent(P), src) # it's not permuted
else
R1 = CartesianIndices(axes(src)[1:d])
d1 = findfirst(isequal(d + 1), perm)::Int # first permuted dim of dest
R2 = CartesianIndices(axes(src)[(d + 2):(d1 - 1)])
R3 = CartesianIndices(axes(src)[(d1 + 1):end])
_permutedims!(P, src, R1, R2, R3, d + 1, d1)
end
return P
end

@noinline function _permutedims!(
P::NestedPermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp
)
ip, is = axes(src, dp), axes(src, ds)
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
for I3 in R3, I2 in R2
for j in jo:min(jo + 7, last(ip))
for i in io:min(io + 7, last(is))
@inbounds P[i, I2, j, I3] = src[i, I2, j, I3]
end
end
end
end
return P
end

@noinline function _permutedims!(P::NestedPermutedDimsArray, src, R1, R2, R3, ds, dp)
ip, is = axes(src, dp), axes(src, ds)
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
for I3 in R3, I2 in R2
for j in jo:min(jo + 7, last(ip))
for i in io:min(io + 7, last(is))
for I1 in R1
@inbounds P[I1, i, I2, j, I3] = src[I1, i, I2, j, I3]
end
end
end
end
end
return P
end

const CommutativeOps = Union{
typeof(+),
typeof(Base.add_sum),
typeof(min),
typeof(max),
typeof(Base._extrema_rf),
typeof(|),
typeof(&),
}

function Base._mapreduce_dim(
f, op::CommutativeOps, init::Base._InitialValue, A::NestedPermutedDimsArray, dims::Colon
)
return Base._mapreduce_dim(f, op, init, parent(A), dims)
end
function Base._mapreduce_dim(
f::typeof(identity),
op::Union{typeof(Base.mul_prod),typeof(*)},
init::Base._InitialValue,
A::NestedPermutedDimsArray{<:Union{Real,Complex}},
dims::Colon,
)
return Base._mapreduce_dim(f, op, init, parent(A), dims)
end

function Base.mapreducedim!(
f, op::CommutativeOps, B::AbstractArray{T,N}, A::NestedPermutedDimsArray{S,N,perm,iperm}
) where {T,S,N,perm,iperm}
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
return B
end
function Base.mapreducedim!(
f::typeof(identity),
op::Union{typeof(Base.mul_prod),typeof(*)},
B::AbstractArray{T,N},
A::NestedPermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
) where {T,N,perm,iperm}
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
return B
end

function Base.showarg(
io::IO, A::NestedPermutedDimsArray{T,N,perm}, toplevel
) where {T,N,perm}
print(io, "NestedPermutedDimsArray(")
Base.showarg(io, parent(A), false)
print(io, ", ", perm, ')')
toplevel && print(io, " with eltype ", eltype(A))
return nothing
end

end
2 changes: 2 additions & 0 deletions NDTensors/src/lib/NestedPermutedDimsArrays/test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
23 changes: 23 additions & 0 deletions NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@eval module $(gensym())
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
using Test: @test, @testset
@testset "NestedPermutedDimsArrays" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
)
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
perm = (3, 2, 1)
p = NestedPermutedDimsArray(a, perm)
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
@test size(p) == (4, 3, 2)
@test eltype(p) === T
for I in eachindex(p)
@test size(p[I]) == (4, 3, 2)
@test p[I] == permutedims(a[CartesianIndex(reverse(Tuple(I)))], perm)
end
x = randn(elt, 4, 3, 2)
p[3, 2, 1] = x
@test p[3, 2, 1] == x
@test a[1, 2, 3] == permutedims(x, perm)
end
end
1 change: 1 addition & 0 deletions NDTensors/test/lib/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using Test: @testset
"LabelledNumbers",
"MetalExtensions",
"NamedDimsArrays",
"NestedPermutedDimsArrays",
"SmallVectors",
"SortedSets",
"SparseArrayDOKs",
Expand Down

0 comments on commit 3594216

Please sign in to comment.