Skip to content

Commit

Permalink
Return all inputs, not just the final one
Browse files Browse the repository at this point in the history
Add pretty printing
  • Loading branch information
theabhirath committed Jun 1, 2022
1 parent 96a6448 commit 617ae2a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
9 changes: 7 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -595,19 +595,24 @@ end
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
for i in 1:N - 1])
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
push!(calls, :(return $(y_symbols[N])))
push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
return Expr(:block, calls...)
end

@functor PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))

function Base.show(io::IO, m::PairwiseFusion)
print(io, "PairwiseFusion(", m.connection, ", ")
_show_layers(io, m.layers)
print(io, ")")
end

"""
Embedding(in => out; init=randn)
Expand Down
3 changes: 2 additions & 1 deletion src/layers/show.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

for T in [
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout # container types
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
Expand Down Expand Up @@ -53,6 +53,7 @@ _show_children(x) = trainable(x) # except for layers which hide their Tuple:
_show_children(c::Chain) = c.layers
_show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
Expand Down

0 comments on commit 617ae2a

Please sign in to comment.