faster default gradient performance
MikeInnes committed Nov 12, 2018
1 parent 95fb460 commit b333120
Showing 3 changed files with 19 additions and 17 deletions.
src/tracker/Tracker.jl (6 changes: 4 additions & 2 deletions)
@@ -5,7 +5,8 @@ using MacroTools: @q, @forward

import Base: ==

-export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
+export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
+  param, back!

tracker(x) = nothing

@@ -99,7 +100,8 @@ end

nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
-@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
+@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
+@grad nobacksies(f::String, x) = data(x), Δ -> error(f)

param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
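Note on the `nobacksies` change above: the `Symbol` method keeps the generic "Nested AD not defined for $f" error, while the new `String` method throws the given string verbatim. This is what lets the non-nested `gradient_` added in src/tracker/back.jl below point users at `nest = true` instead of a cryptic message. A minimal sketch of the behaviour, assuming Flux's Tracker as of this commit (the `:identity` tag, the array, and the "custom hint" string are illustrative, not from the commit):

```julia
# Sketch only, not part of the commit: how the two `nobacksies` methods report
# errors once a pullback is actually requested.
using Flux.Tracker: nobacksies, back!, param

x = param([1.0, 2.0])

y = nobacksies(:identity, x)             # Symbol: generic message
# back!(sum(y))  # expected to error with "Nested AD not defined for identity"

z = nobacksies("custom hint here", x)    # String: the string itself is the error
# back!(sum(z))  # expected to error with "custom hint here"
```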
src/tracker/back.jl (24 changes: 12 additions & 12 deletions)
@@ -66,6 +66,15 @@ function back!(x, Δ; once = true)
  return
end

+function gradient_(f, xs...)
+  xs = param.(xs)
+  l = f(xs...)
+  losscheck(l)
+  back!(l)
+  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
+             grad.(xs))
+end
+
# Out-of-place gradients

struct Params
@@ -162,20 +171,11 @@ function losscheck(x)
  isnan(x) && error("Loss is NaN")
end

-function gradient(f, args...)
+function gradient_nested(f, args...)
  y, back = forward(f, args...)
  losscheck(y)
  return back(1)
end

-derivative(f, x) = gradient(f, x)[1]
-
-# Non-nesting versions
-
-function gradient_(f, xs...)
-  xs = param.(xs)
-  l = f(xs...)
-  losscheck(l)
-  back!(l)
-  grad.(xs)
-end
+gradient(f, xs...; nest = false) =
+  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
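Usage sketch for the reworked `gradient` (not part of the commit; `f` and the printed values are illustrative). By default it now takes the faster in-place `back!` path and returns gradients wrapped in `nobacksies`, so they cannot be differentiated again; passing `nest = true` selects the `forward`/`back` path that supports nested derivatives.

```julia
# Sketch only: default (non-nested) vs. nested gradient calls.
using Flux.Tracker: gradient, data

f(a, b) = a * b + a^2

# Default, nest = false: params the inputs, runs back!, returns the gradients.
da, db = gradient(f, 2, 3)
data(da), data(db)               # (7.0, 2.0)

# nest = true: goes through forward/back, so the results stay differentiable.
da_n, db_n = gradient(f, 2, 3; nest = true)
```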
test/tracker.jl (6 changes: 3 additions & 3 deletions)
@@ -1,6 +1,6 @@
using Flux
using Flux.Tracker, Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
+using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
@@ -285,9 +285,9 @@ end
    count += 1
    a * b
  end
-  @test derivative(x -> mul(5, x), 3) == 5
+  @test gradient(x -> mul(5, x), 3)[1] == 5
  @test count == 1
-  @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
+  @test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
  @test count == 3
end
