Merge FluxML#1524
1524: Add identity_init r=darsnack a=DrChainsaw

As discussed in FluxML#1431 (comment)

Since the only information available is the dimensions, it makes the following assumptions:

*  1D: A `Vector` of `zeros` (assumes bias)
*  2D: An identity matrix (assumes matrix multiplication)
*  2+D: A diagonal matrix of identity kernels (assumes convolution)

Not sure if there is a better description of the last one (that's just how I envision it).

I don't think it is possible to create identity mappings for all layer types, and if it is, doing so probably requires that the weight init knows the layer type. The layers which will not get an identity mapping from this init are the RNNs, the normalization layers (which don't need it anyway), and DepthwiseConv.
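As a quick illustration of the intended usage (a sketch, not part of this diff; it only uses `identity_init` as added here): the 2D case is a plain identity matrix and the 1D case a zero bias, so a Dense-style affine map built from them reproduces its input exactly.

```julia
using Flux

# Sketch: W is the 3×3 identity, b is a zero bias, so the affine map
# W * x .+ b is the identity mapping on 3-element inputs.
W = Flux.identity_init(3, 3)   # 2D case: identity matrix
b = Flux.identity_init(3)      # 1D case: zeros (bias)
x = Float32[1, 2, 3]
@assert W * x .+ b == x
```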

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: DrChainsaw <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
4 people authored Mar 6, 2021
2 parents 69e2198 + 3f5045b commit e0c8c6b
Showing 3 changed files with 134 additions and 0 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.12.0

* Add [identity_init](https://github.com/FluxML/Flux.jl/pull/1524).
* Add [Orthogonal Matrix initialization](https://github.com/FluxML/Flux.jl/pull/1496) as described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural networks](https://arxiv.org/abs/1312.6120).
* Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to the Losses module.
* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
85 changes: 85 additions & 0 deletions src/utils.jl
@@ -284,6 +284,91 @@ end
sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

"""
identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)
Return an `Array` of size `dims` which yields an identity mapping when used as parameters in
most Flux layers. Use `gain` to scale the identity by a constant.
Often useful in the context of transfer learning, i.e when one wants to add more capacity to
a model but start from the same mapping.
Use `shift` (integer or tuple) to apply circular shift to the output.
Equivalent to `Base.circshift(identity(dims...), shift)`.
Some caveats: Not all layers will be identity mapping when used with this init. Exceptions
include recurrent layers, `DepthwiseConv` and normalization layers.
Also note that layers must have `input_size == output_size` for identity mapping to be
possible. When this is not the case, extra dimensions of the array are padded with zeros.
For convolutional layers, in addition to the above, the kernel sizes must also be odd and
padding must be applied so that output feature maps have the same size as input feature maps,
e.g by using [`SamePad`](@ref).
Has the following behaviour
* 1D: A `Vector` of `zeros` (useful for an identity bias)
* 2D: An identity matrix (useful for an identity matrix multiplication)
* More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)
```jldoctest
julia> Flux.identity_init(3,3)
3×3 Array{Float32,2}:
 1.0  0.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0

julia> Flux.identity_init(3,5)
3×5 Array{Float32,2}:
 1.0  0.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0  0.0
 0.0  0.0  1.0  0.0  0.0

julia> Flux.identity_init(3,3,2,2)
3×3×2×2 Array{Float32,4}:
[:, :, 1, 1] =
 0.0  0.0  0.0
 0.0  1.0  0.0
 0.0  0.0  0.0

[:, :, 2, 1] =
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

[:, :, 1, 2] =
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

[:, :, 2, 2] =
 0.0  0.0  0.0
 0.0  1.0  0.0
 0.0  0.0  0.0
```
"""
# Assume bias
identity_init(cols; gain=1, shift=0) = zeros(Float32, cols)

# Assume matrix multiplication
identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows, cols), shift)

# Assume convolution
function identity_init(dims...; gain=1, shift=0)
  nin, nout = dims[end-1], dims[end]
  centers = map(d -> cld(d, 2), dims[1:end-2])
  weights = zeros(Float32, dims)
  for i in 1:min(nin, nout)
    weights[centers..., i, i] = gain
  end
  return circshift(weights, shift)
end

identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)


ones(T::Type, dims...) = Base.ones(T, dims...)
zeros(T::Type, dims...) = Base.zeros(T, dims...)
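
For intuition on the convolution branch above (a sketch under the docstring's assumptions, not part of the diff): a kernel that is zero everywhere except a `gain` at its spatial center acts as a Dirac delta under same-padded convolution, and placing one such kernel at each `(i, i)` position of the channel dimensions makes output channel `i` a copy of input channel `i`.

```julia
using Flux

# Sketch: inspect the center-tap structure for a 3×3 kernel with
# 2 input and 2 output channels, then check the identity property.
w = Flux.identity_init(3, 3, 2, 2)
@assert w[2, 2, 1, 1] == 1f0 && w[2, 2, 2, 2] == 1f0  # center taps on the channel diagonal
@assert sum(w) == 2                                    # nothing else is set

c = Conv((3, 3), 2 => 2; init = Flux.identity_init, pad = SamePad())
x = randn(Float32, 5, 5, 2, 1)
@assert c(x) ≈ x   # same-padded center-tap convolution is the identity
```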

48 changes: 48 additions & 0 deletions test/utils.jl
@@ -150,6 +150,54 @@ end
    @test maximum(partial_si(8, 8)) == 0
    @test maximum(partial_si(8, 8, sparsity=0)) > 0
  end

  @testset "identity_init" begin
    import Flux: identity_init

    @testset "Basic" begin
      partial = identity_init(gain=3)
      @test partial(3, 3) == identity_init(3, 3; gain=3) == [3 0 0; 0 3 0; 0 0 3]
    end

    @testset "Non-identity sizes" begin
      @test identity_init(2, 3)[:, end] == zeros(Float32, 2)
      @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2)
      @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3)
      @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3)
      @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3)
    end

    @testset "Dense ID mapping" begin
      l = Dense(3, 3, initW = identity_init)

      indata = reshape(collect(Float32, 1:9), 3, 3)
      @test l(indata) == indata
    end

    @testset "$layer ID mapping with kernelsize $kernelsize" for layer in (Conv, ConvTranspose, CrossCor), kernelsize in (
      (1,),
      (3,),
      (1, 3),
      (3, 5),
      (3, 5, 7))
      nch = 3
      l = layer(kernelsize, nch => nch, init = identity_init, pad = SamePad())

      indata = randn(Float32, kernelsize..., nch, nch)
      @test l(indata) == indata
    end

    @testset "Inception identity" begin
      insize = 7
      path1 = Conv((1, 3), insize => 2; init = identity_init, pad = SamePad())
      path2 = Conv((3, 5), insize => 3; init = identity_init(shift = (0, 0, 2, 0)), pad = SamePad())
      path3 = Conv((5, 7), insize => 2; init = identity_init(shift = (0, 0, 5, 0)), pad = SamePad())
      block = Parallel((xs...) -> cat(xs...; dims = 3), path1, path2, path3)

      indata = randn(Float32, 9, 9, 7, 2)
      @test block(indata) == indata
    end
  end
end
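
The Inception testset passes because each path reproduces a disjoint slice of the input channels: `path1` (7=>2, no shift) copies channels 1–2, `path2` (7=>3, shifted by 2 along the input-channel dimension) copies channels 3–5, and `path3` (7=>2, shifted by 5) copies channels 6–7, so concatenating along the channel dimension reassembles the input. A standalone sketch of the same idea with 1×1 kernels (an assumed example, not from the diff):

```julia
using Flux

# Three 1×1 conv paths that together cover all 5 input channels;
# shifting along dimension 3 (input channels) selects which slice
# each path passes through.
p1 = Conv((1, 1), 5 => 2; init = Flux.identity_init)                        # channels 1:2
p2 = Conv((1, 1), 5 => 2; init = Flux.identity_init(shift = (0, 0, 2, 0)))  # channels 3:4
p3 = Conv((1, 1), 5 => 1; init = Flux.identity_init(shift = (0, 0, 4, 0)))  # channel 5
block = Parallel((xs...) -> cat(xs...; dims = 3), p1, p2, p3)

x = randn(Float32, 8, 8, 5, 1)
@assert block(x) ≈ x   # the concatenated slices reassemble the input
```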

@testset "Params" begin
