Skip to content

Commit

Permalink
Remove _parent
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Apr 5, 2022
1 parent c831955 commit 0d55c00
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
5 changes: 1 addition & 4 deletions src/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
copyto!(dst, src)
end

_parent(x) = x
_parent(x::AbstractArray) = parent(x)

_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
Expand Down Expand Up @@ -79,7 +76,7 @@ function loadmodel!(dst, src; cache = Base.IdSet())

err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
foreach(ldsts, lsrcs) do ldst, lsrc
if _parent(ldst) in cache # we already loaded this parameter before
if ldst in cache # we already loaded this parameter before
_tie_check(ldst, lsrc) && return ldst
elseif Functors.isleaf(ldst) # our first time loading this leaf
push!(cache, ldst)
Expand Down
19 changes: 7 additions & 12 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,21 +483,16 @@ end
@test chain2[2].bias != chain1[2].bias

# test shared weights
encoder_dst = Chain(Dense(10 => 5), Dense(5 => 2))
decoder_dst = Chain(Dense(transpose(encoder_dst[2].weight)),
Dense(permutedims(encoder_dst[1].weight)))
encoder_src = Chain(Dense(10 => 5), Dense(5 => 2))
decoder_src = Chain(Dense(transpose(encoder_src[2].weight)),
Dense(5 => 10))
shared_dst = Dense(10 => 10)
shared_src = Dense(10 => 10)
# matched weights are okay
m1 = Chain(encoder_dst, decoder_dst)
m2 = Chain(encoder_src, decoder_src)
m1 = Chain(shared_dst, Dense(shared_dst.weight))
m2 = Chain(shared_src, Dense(shared_src.weight))
loadmodel!(m1, m2)
@test m1[1][2].weight === parent(m1[2][1].weight)
@test m1[1][1].weight == m2[1][1].weight
@test m1[1][1].weight != permutedims(m1[2][2].weight)
@test m1[1].weight === m1[2].weight
@test m1[1].weight == m2[2].weight
# mismatched weights are an error
m2 = Chain(Chain(Dense(10 => 5), Dense(5 => 2)), Chain(Dense(2 => 5), Dense(5 => 10)))
m2 = Chain(Dense(10 => 10), Dense(10 => 10))
@test_throws ErrorException loadmodel!(m1, m2)
# loading into tied weights with absent parameter is okay when the dst == zero
b = Flux.zeros32(5)
Expand Down

0 comments on commit 0d55c00

Please sign in to comment.