Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Magic Dots #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ version = "0.0.1"

[deps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[compat]
DiffRules = "1"
EllipsisNotation = "0.4"
Requires = "1"
julia = "1.3"

Expand Down
3 changes: 3 additions & 0 deletions src/Tullio.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module Tullio

using EllipsisNotation: (..)
using Base.Broadcast: newindex, newindexer, combine_axes # , check_broadcast_axes

#========== ⚜️ ==========#

export @tullio
Expand Down
137 changes: 115 additions & 22 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,19 @@ function _tullio(exs...; mod=Main)
# Everything writes into leftarray[leftraw...], sometimes with a generated name
leftraw = [],
leftind = Symbol[], # vcat(leftind, redind) is the complete list of loop indices
leftcast = nothing,
leftarray = nothing,
leftscalar = nothing, # only defined for scalar reduction
leftnames = Symbol[], # for NamedDims
# Whole RHS, untouched, plus things extracted:
right = nothing,
rightind = Symbol[],
sharedind = Symbol[], # indices appearing on every RHS array
rightcast = Expr[], # things like @view A[end,end,..] for broadcasting
arrays = Symbol[],
scalars = Symbol[],
broadargs = Symbol[], # results of newindexer, passed to actor as tuple BROAD
broadarrays = Symbol[], # arrays for which that must be unpacked for newindex
cost = 1,
# Index ranges: first save all known constraints
constraints = Dict{Symbol,Vector}(), # :k => [:(axis(A,2)), :(axis(B,1))] etc.
Expand Down Expand Up @@ -203,11 +207,12 @@ end

# These only need not to clash with symbols in the input:
RHS, AXIS = :𝓇𝒽𝓈, :𝒶𝓍
ZED, TYP, ACC, KEEP = :ℛ, :𝒯, :𝒜𝒸𝒸, :♻
EPS, DEL, EXPR = :𝜀, :𝛥, :ℰ𝓍
ZED, TYP, ACC, KEEP = :ℛ, :𝒯, :𝒜𝒸𝒸, :♻ # used in act! function
EPS, DEL, EXPR = :𝜀, :𝛥, :ℰ𝓍 # used for derivatives
CART, IND, BOOL, FIRST, BROAD = :𝒞𝒶𝓇𝓉, :𝒾𝓃𝒹, :𝒷ℴℴ𝓁, :𝒻𝒾𝓇𝓈𝓉, :𝒷𝓇ℴ𝒶𝒹 # used for broadcasting ⚃

# These get defined globally, with a random number appended:
MAKE, ACT! = :𝒞𝓇ℯ𝒶𝓉ℯ, :𝒜𝒸𝓉! # :ℳ𝒶𝓀ℯ
MAKE, ACT! = :ℳ𝒶𝓀ℯ, :𝒜𝒸𝓉!

#========== input parsing ==========#

Expand Down Expand Up @@ -250,6 +255,7 @@ function parse_input(expr, store)
unique!(store.leftind) # after last saveconstraints()
unique!(store.sharedind)
unique!(store.rightind)
unique!(store.broadargs)
unique!(store.outpre) # kill mutiple assertions, and evaluate any f(A) only once

store.redind = setdiff(store.rightind, store.leftind)
Expand Down Expand Up @@ -281,7 +287,7 @@ rightwalk(store) = ex -> begin
# Third, save letter A, and what axes(A) says about indices:
push!(store.arrays, arrayonly(A))
inds = primeindices(inds)
saveconstraints(A, inds, store, true)
inds = saveconstraints(A, inds, store, true)

# Re-assemble RHS with new A, and primes on indices taken care of.
return :( $A[$(inds...)] )
Expand All @@ -297,7 +303,25 @@ saveconstraints(A, inds, store, right=true) = begin
A1 = arrayfirst(A)
is = Symbol[]
foreach(enumerate(inds)) do (d,ex)
isconst(ex) && return
isconst(ex) && return # ?? now that saveconstraints() returns inds, you could do dollars here, is that better? You still need to catch any outside of indexing.

if ex in (:(..), CART) # broadcasting!
d == length(inds) || error("can only use .. for broadcasting except after explicit indices")
ends = repeat([:end], d-1)
vex = :(@view $A1[$(ends...),$..])
if right
push!(store.rightcast, vex) # will be used to compute range of CART
push!(store.broadarrays, A1)
inds[d] = Symbol(IND, A1)
else
store.leftcast = vex
end
boolA, firstA = Symbol(BOOL, A1), Symbol(FIRST, A1)
push!(store.axisdefs, :(local $boolA, $firstA = $newindexer($vex))) # done inside maker
push!(store.broadargs, boolA, firstA) # pass all from maker to actor as tuple BROAD
return # editing of RHS must happen later
end

range_i, i = range_expr_walk(length(inds)==1 ? :(eachindex($A1)) : :(axes($A1,$d)), ex)
if i isa Symbol
push!(is, i)
Expand All @@ -309,6 +333,7 @@ saveconstraints(A, inds, store, right=true) = begin
push!(store.shiftedind, i...)
push!(store.pairconstraints, (i..., dollarstrip.(range_i)...))
end

end
if right
append!(store.rightind, is)
Expand All @@ -323,13 +348,18 @@ saveconstraints(A, inds, store, right=true) = begin
append!(store.leftind, is) # why can's this be the only path for store.leftind??
end
n = length(inds)
if n==1
if !isempty(store.rightcast)
str = "expected a $A1 to have at least $(n-1) dimensions"
push!(store.outpre, :( ndims($A1) >= $(n-1) || error($str) ))
elseif n==1
str = "expected a 1-array $A1, or a tuple"
push!(store.outpre, :( $A1 isa Tuple || ndims($A1) == 1 || error($str) ))
else
str = "expected a $n-array $A1" # already arrayfirst(A)
push!(store.outpre, :( ndims($A1) == $n || error($str) ))
end

inds
end

arrayfirst(A::Symbol) = A # this is for axes(A,d), axes(first(B),d), etc.
Expand All @@ -352,7 +382,7 @@ dollarwalk(store) = ex -> begin
@nospecialize ex
ex isa Expr || return ex
if ex.head == :call
ex.args[1] == :* && ex.args[2] === Int(0) && return false # tidy up dummy arrays!
ex.args[1] == :* && ex.args[2] === Int(0) && return false # tidy up dummy arrays! ?? these were needed before explicit ranges, can delete
callcost(ex.args[1], store) # cost model for threading
elseif ex.head == :$ # interpolation of $c things:
ex.args[1] isa Symbol || error("you can only interpolate single symbols, not $ex")
Expand All @@ -378,6 +408,8 @@ tidyleftraw(leftraw, store) = map(leftraw) do i
end
elseif i === :_
return 1
elseif i === :(..) # broadcasting!
return CART # This symbol ends up in leftind, which is good. But for in-pace, it's too early, will miss saveconstraints... unless that looks for CART too...
end
i
end
Expand Down Expand Up @@ -431,6 +463,7 @@ end
function index_ranges(store)

todo = Set(vcat(store.leftind, store.redind))
pop!(todo, CART, nothing)

for (i,j,r_i,r_j) in store.pairconstraints
if haskey(store.constraints, i) # && i in todo ??
Expand All @@ -455,6 +488,34 @@ function index_ranges(store)
end
end

if isempty(store.rightcast) # no broadcasting, but make some trivial definitions
axs = Symbol(CART, :axes)
push!(store.axisdefs, quote
local $axs = ()
local $BROAD = ()
end)
else # broadcasting!
axs = Symbol(CART, :axes) # this is the tuple of axes, also used for making a new array
carts = Symbol(AXIS, CART) # this is the CartesianIndex iterated over
push!(store.axisdefs, quote
local $axs = $combine_axes($(store.rightcast...))
local $carts = $CartesianIndices($axs)
local $BROAD = tuple($(store.broadargs...))
end)
if !(:newarray in store.flags)
push!(store.axisdefs, :($axs == $axes($(store.leftcast)) ||
throw(DimensionMismatch(("LHS does not match broadcast dimensions from RHS")))))
end
# Now also deal with RHS, where we must use CART + broadargs to calculate IND_A
rightex = map(store.broadarrays) do A
indA = Symbol(IND, A)
boolA, firstA = Symbol(BOOL, A), Symbol(FIRST, A)
:($indA = $newindex($CART, $boolA, $firstA))
end
store.right = :($(rightex...); $(store.right))

end

append!(store.outex, store.axisdefs)
end

Expand Down Expand Up @@ -500,6 +561,7 @@ function output_array(store)
# This now checks for OffsetArrays, and allows A[i,1] := ...
outaxes = map(store.leftraw) do i
i isa Integer && i==1 && return :(Base.OneTo(1))
i == CART && return :($(Symbol(CART, :axes))...)
i isa Symbol && return Symbol(AXIS, i)
error("can't use index $i on LHS for a new array")
end
Expand All @@ -515,7 +577,7 @@ function output_array(store)

simex = if isempty(store.arrays)
# :( zeros($TYP, tuple($(outaxes...))) ) # Array{T} doesn't accept ranges... but zero() doesn't accept things like @tullio [i,j] := (i,j) i ∈ 2:3, j ∈ 4:5
:( similar([], $TYP, tuple($(outaxes...))) )
:( similar([], $TYP, tuple($(outaxes...)),) )
else
:( similar($(store.arrays[1]), $TYP, tuple($(outaxes...),)) )
end
Expand Down Expand Up @@ -561,10 +623,14 @@ function action_functions(store)
push!(store.outeval, quote
function $make($(store.arrays...), $(store.scalars...), )
$sofar
$threader($act!, $ST, $(store.leftarray),
tuple($(store.arrays...), $(store.scalars...),),
tuple($(axisleft...),), tuple($(axisred...),);
block=$block, keep=$keep)
# $threader($act!, $ST, $(store.leftarray),
# tuple($(store.arrays...), $(store.scalars...),),
# tuple($(axisleft...),), tuple($(axisred...),); # missing BROAD
# block=$block, keep=$keep)
$act!($ST, $(store.leftarray),
$(store.arrays...), $(store.scalars...),
$(axisleft...), $(axisred...), $BROAD,
$keep)
return $(store.leftarray)
end
end)
Expand Down Expand Up @@ -636,15 +702,37 @@ function action_functions(store)
store.threads==true ? (BLOCK[] ÷ store.cost) :
store.threads
push!(store.outex, quote
$threader($act!, $ST, $(store.leftarray),
tuple($(store.arrays...), $(store.scalars...),),
tuple($(axisleft...),), tuple($(axisred...),);
block = $block, keep = $keep)
# $threader($act!, $ST, $(store.leftarray),
# tuple($(store.arrays...), $(store.scalars...),),
# tuple($(axisleft...),), tuple($(axisred...),);
# block = $block, keep = $keep)
$act!($ST, $(store.leftarray),
$(store.arrays...), $(store.scalars...),
$(axisleft...), $(axisred...), $BROAD,
$keep)
$(store.leftarray)
end)
end
end

#=
Design notes:

I don't like constructing the `threader($act!, ...)` call twice.

I'd like to only `@eval` when gradients are being computed.

Order of definitions:
* Without gradients, `act!` should be defined before `make`. Or perhaps there is no `make` function at all in that case.
* With gradients, both `make` and `act!` should be defined before `@adjoint` etc.


First I add the code which makes a new matrix.
If I need `make()`, I can slurp that code up and wrap it in a function.

Once I add `act!`, I can no longer slurp.
The function `act!` should come first, but if it's not a function, then it comes later...

=#


"""
make_many_actors(f!, args, ex1, [:i,], ex3, [:k,], ex5, ex6, store)
Expand All @@ -653,7 +741,7 @@ This makes several functions of this form,
decorated as necessary with `@inbounds` or `@avx` etc,
and with appropriate `storage_type` as the first argument.
```
f!(::Type, args..., keep=nothing) where {T}
f!(::Type, args..., broad=(), keep=nothing) where {T}
ex1
ex2 = (for i in axis_i
ex3
Expand All @@ -667,12 +755,17 @@ end
"""
function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex5, ex6, store)

if !isempty(store.broadargs)
bex = :(($(store.broadargs...),) = $BROAD) # broadcasting! unpack the extra arguments
ex1 = :($bex; $ex1)
end

ex4 = recurseloops(ex5, inner)
ex2 = recurseloops(:($ex3; $ex4; $ex6), outer)

push!(store.outeval, quote

function $act!(::Type, $(args...), $KEEP=nothing) where {$TYP}
function $act!(::Type, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP}
@debug "base actor:" typeof.(tuple($(args...)))
@inbounds @fastmath ($ex1; $ex2)
end
Expand All @@ -690,7 +783,7 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex
unroll = store.avx == true ? 0 : store.avx # unroll=0 is the default setting
push!(store.outeval, quote

function $act!(::Type{<:Array{<:Union{Base.HWReal, Bool}}}, $(args...), $KEEP=nothing) where {$TYP}
function $act!(::Type{<:Array{<:Union{Base.HWReal, Bool}}}, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP}
@debug "LoopVectorization @avx actor, unroll=$unroll"
$expre
LoopVectorization.@avx unroll=$unroll $exloop
Expand All @@ -711,12 +804,12 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex
sizes = map(ax -> :(length($ax)), axouter)
push!(store.outeval, quote

KernelAbstractions.@kernel function $kernel($(args...), $KEEP) where {$TYP}
KernelAbstractions.@kernel function $kernel($(args...), $BROAD, $KEEP) where {$TYP}
($(outer...),) = @index(Global, NTuple)
($ex1; $ex3; $ex4; $ex6)
end

function $act!(::Type{<:CuArray}, $(args...), $KEEP=nothing) where {$TYP}
function $act!(::Type{<:CuArray}, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP}
@debug "KernelAbstractions CuArrays actor"
cu_kern! = $kernel(CUDA(), $(store.cuda))
$(asserts...)
Expand All @@ -725,7 +818,7 @@ function make_many_actors(act!, args, ex1, outer::Vector, ex3, inner::Vector, ex
end

# Just for testing really...
function $act!(::Type{<:Array}, $(args...), $KEEP=nothing) where {$TYP}
function $act!(::Type{<:Array}, $(args...), $BROAD=(), $KEEP=nothing) where {$TYP}
@debug "KernelAbstractions CPU actor:" typeof.(tuple($(args...)))
cpu_kern! = $kernel(CPU(), Threads.nthreads())
$(asserts...)
Expand Down
2 changes: 2 additions & 0 deletions src/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ Then it divides up the other axes, each accumulating in its own copy of `Z`.
`keep=nothing` means that it overwrites the array, anything else (`keep=true`) adds on.
"""
function threader(fun!::Function, T::Type, Z::AbstractArray, As::Tuple, I0s::Tuple, J0s::Tuple; block, keep=nothing)
return fun!(T, Z, As..., I0s..., J0s..., keep)
# not yet fixed up for broadcasting
Is = map(UnitRange, I0s)
Js = map(UnitRange, J0s)
if isnothing(block)
Expand Down
28 changes: 28 additions & 0 deletions test/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,34 @@ end

end

@testset "broadcasting" begin

    # `..` on both sides: the same definition is applied to 1-, 2- and 3-dimensional input.
    plusone(A) = @tullio C[i, ..] := A[i, ..] + 1
    for sz in ((3,), (3, 4), (3, 4, 5))
        @test plusone(ones(sz)) == ones(sz) .+ 1
    end

    # Index `k` appears only on the right; the result is compared against a sum over dims=2.
    sumsecond(A) = @tullio C[i, ..] := A[i, k, ..]
    @test sumsecond(ones(3, 4)) == fill(4.0, 3)
    cube = rand(3, 4, 5)
    @test sumsecond(cube) ≈ dropdims(sum(cube; dims=2); dims=2)

    # Contraction over `k`, checked against `A * B'` for matrices and slice-by-slice for 3-arrays.
    slicemul(A, B) = @tullio C[i,j, ..] := A[i, k, ..] * B[j, k, ..]
    matA = rand(3, 3)
    matB = rand(3, 3)
    @test slicemul(matA, matB) ≈ matA * matB'

    stackA = rand(3, 3, 2)
    stackB = rand(3, 3, 2)
    both3 = slicemul(stackA, stackB)
    for s in 1:2
        @test both3[:, :, s] ≈ stackA[:, :, s] * stackB[:, :, s]'
    end

    # Mixed ranks: a 3-array against a matrix, the matrix being paired with every slice.
    mixed = slicemul(stackA, matB)
    for s in 1:2
        @test mixed[:, :, s] ≈ stackA[:, :, s] * matB[:, :]'
    end

end

@testset "without packages" begin

A = [i^2 for i in 1:10]
Expand Down