break all the things
MikeInnes authored and staticfloat committed May 3, 2019
1 parent e991228 commit aa4d221
Showing 9 changed files with 42 additions and 96 deletions.
14 changes: 14 additions & 0 deletions Manifest.toml
@@ -111,6 +111,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

+[[IRTools]]
+deps = ["InteractiveUtils", "MacroTools", "Test"]
+git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53"
+uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
+version = "0.1.2"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -300,3 +306,11 @@ deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.1"

+[[Zygote]]
+deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
+git-tree-sha1 = "7fcb55117550e1c195a646947135cc9aac1e2afc"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/Zygote.jl.git"
+uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
+version = "0.1.0+"
3 changes: 3 additions & 0 deletions Project.toml
@@ -22,6 +22,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
NNlib = "0.6"
4 changes: 1 addition & 3 deletions src/Flux.jl
@@ -12,9 +12,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanP

@reexport using NNlib

-using Tracker
-using Tracker: data
-export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
+using Zygote

include("optimise/Optimise.jl")
using .Optimise
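Dropping `using Tracker` for `using Zygote` changes the user-facing AD model: parameters are no longer wrapped in `TrackedArray`, and Zygote differentiates ordinary functions of plain arrays. A minimal sketch of the difference (illustrative only, not part of this diff):

using Zygote

W, x = rand(2, 3), rand(3)

# Tracker style (removed above): wrap, run forward, backpropagate, unwrap:
#   W = param(rand(2, 3)); l = sum(W * x); Tracker.back!(l); Tracker.grad(W)

# Zygote style: differentiate a plain function of plain arrays
gs = gradient(W -> sum(W * x), W)   # returns a 1-tuple holding the 2×3 gradient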
30 changes: 1 addition & 29 deletions src/cuda/cudnn.jl
@@ -196,33 +196,5 @@ end
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))

-batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
+@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
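Tracker's `@grad` becomes Zygote's `@adjoint`; both pair a forward result with a pullback that maps the output sensitivity Δ to per-argument gradients. A generic sketch of the Zygote form on a made-up function (not from this commit):

using Zygote: @adjoint

myscale(a::Real, x) = a .* x

# forward value, plus a pullback returning (∂a, ∂x)
@adjoint myscale(a::Real, x) = myscale(a, x), Δ -> (sum(Δ .* x), a .* Δ)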
41 changes: 15 additions & 26 deletions src/cuda/curnn.jl
@@ -221,7 +221,6 @@ end
# Interface

import ..Flux: Flux, relu
-import ..Tracker: TrackedArray
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims

@@ -236,10 +235,9 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
return dst
end

-CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
+CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
+CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
+CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}

function copyparams!(m::CuRNNs, d::RNNDesc)
@@ -267,37 +265,28 @@ function desc(rnn)
return d
end

-import Flux.Tracker
-import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
+using Zygote: @adjoint

-istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))

-function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
+function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h)
return result[2], result[1]
end

-function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
+function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h)
return result[2], result[1]
end

-function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h[1], h[2])
+function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end

-(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))

-@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
+@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
@@ -309,7 +298,7 @@ end
end
end

-@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
+@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
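With the `CuParam`/`TrackedArray` dispatch gone, the CUDNN cells take plain CuArrays and the training path is selected by the `@adjoint` definitions rather than by `istrain`. Roughly how this gets exercised (a sketch; assumes a working CuArrays/CUDA setup):

using Flux

m = Flux.RNN(10, 5) |> gpu
x = gpu(rand(Float32, 10))

# taking a gradient routes the cell call through the @adjoint above,
# which uses CUDNN's forwardTrain instead of the plain forward
gs = gradient(() -> sum(m(x)), params(m))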
15 changes: 0 additions & 15 deletions src/layers/recurrent.jl
@@ -42,21 +42,6 @@ end

Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

-_truncate(x::AbstractArray) = Tracker.data(x)
-_truncate(x::Tuple) = _truncate.(x)
-
-"""
-    truncate!(rnn)
-
-Truncates the gradient of the hidden state in recurrent layers. The value of the
-state is preserved. See also `reset!`.
-
-Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
-
-    rnn.state = Tracker.data(rnn.state)
-"""
-truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)

"""
reset!(rnn)
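`truncate!` disappears because Zygote does not store gradient history in the hidden state, so there is nothing to cut out of `rnn.state`; resetting the state itself remains available as `reset!` (whose docstring follows the deleted block). A quick usage reminder (standard Flux API):

m = Chain(LSTM(10, 5), Dense(5, 2))
seq = [rand(Float32, 10) for _ = 1:20]
ys = map(m, seq)      # run across a sequence, carrying state forward
Flux.reset!(m)        # restore the initial hidden state before the next sequence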
10 changes: 3 additions & 7 deletions src/onehot.jl
@@ -129,10 +129,6 @@ function argmax(xs...)
return onecold(xs...)
end

-# Ambiguity hack
-
-a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
-a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
-
-onecold(x::TrackedVector, l...) = onecold(data(x), l...)
-onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
+# TODO probably still want this as a custom adjoint Zygote
+# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
+# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
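One way the TODO above might be resolved: `onecold` picks indices, so it is piecewise constant and a custom adjoint can simply report no gradient. A sketch (hypothetical, not part of this commit):

using Zygote: @adjoint

# `nothing` from the pullback tells Zygote there is no gradient to propagate
@adjoint onecold(y, labels...) = onecold(y, labels...), Δ -> nothing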
13 changes: 2 additions & 11 deletions src/optimise/train.jl
@@ -1,9 +1,9 @@
using Juno
-import Flux.Tracker: Params, gradient, data, update!
+import Zygote: Params, gradient
import Base.depwarn

function update!(opt, x, x̄)
-  update!(x, -apply!(opt, x, data(x̄)))
+  update!(x, -apply!(opt, x, x̄))
end

function update!(opt, xs::Params, gs)
@@ -12,15 +12,6 @@ function update!(opt, xs::Params, gs)
end
end

-# Added as an internal API but everyone started using it.
-function _update_params!(opt, xs)
-  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
-  for x in xs
-    update!(opt, x, Tracker.grad(x))
-    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
-  end
-end

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
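With `update!` now consuming the gradient directly, a training step looks roughly like this (a sketch, given a model m, a loss function, data x and y, and an optimiser opt):

ps = params(m)                        # Zygote.Params
gs = gradient(() -> loss(x, y), ps)   # Zygote.Grads, indexable by parameter
for p in ps
  update!(opt, p, gs[p])
end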
8 changes: 3 additions & 5 deletions src/treelike.jl
@@ -1,5 +1,5 @@
import Adapt: adapt, adapt_storage
-import .Tracker: IdSet
+import .Zygote: IdSet

children(x) = ()
mapchildren(f, x) = x
@@ -39,7 +39,7 @@ end
function params(m)
ps = Params()
prefor(p ->
-    Tracker.istracked(p) && Tracker.isleaf(p) &&
+    p isa AbstractArray{<:Real} &&
!any(p′ -> p′ === p, ps) && push!(ps, p),
m)
return ps
@@ -80,8 +80,6 @@ f64(m) = paramtype(Float64, m)

function mapparams(f, m)
mapleaves(m) do x
-    Tracker.istracked(x) ? param(f(Tracker.data(x))) :
-    x isa Union{AbstractArray,Number} ? f(x) :
-    x
+    x isa Union{AbstractArray,Number} ? f(x) : x
end
end
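`params` now collects any `AbstractArray{<:Real}` leaf instead of hunting for tracked leaves, so models built from plain arrays need no `param` wrappers. For example (standard Flux usage):

m = Chain(Dense(10, 5, relu), Dense(5, 2))
ps = Flux.params(m)   # the weights and biases, now ordinary Arrays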
