
Merge pull request FluxML#1944 from darsnack/darsnack/progress-fix
darsnack authored Apr 14, 2022
2 parents 24f40b6 + f5ecde1 · commit f97688e
Showing 1 changed file with 3 additions and 2 deletions.
src/optimise/train.jl (5 changes: 3 additions & 2 deletions)
@@ -112,7 +112,8 @@ Multiple callbacks can be passed to `cb` as array.
 """
 function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
   cb = runall(cb)
-  n = (Base.IteratorSize(typeof(data)) == Base.HasLength()) ? length(data) : 0
+  itrsz = Base.IteratorSize(typeof(data))
+  n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
     try
       gs = gradient(ps) do
@@ -129,7 +130,7 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
         rethrow(ex)
       end
     end
-    @logprogress i / n
+    @logprogress iszero(n) ? nothing : i / n
   end
 end
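In short: the old check only matched `Base.HasLength()`, so a plain `Vector` of batches, which reports `Base.HasShape{1}()`, got `n = 0` and the progress bar was fed `i / 0 == Inf`. The new code also accepts `HasShape{1}()`, and when the length really is unknown it passes `nothing` to `@logprogress`, which ProgressLogging displays as indeterminate progress. Below is a minimal sketch (not part of the commit; the data variables are illustrative) of what the new expression computes for a few common `data` arguments.

    # Sketch only: shows Base.IteratorSize for typical training-data iterators
    # and what the patched length/progress expressions evaluate to.
    data_vec    = [(rand(2, 3), rand(3)) for _ in 1:5]   # Vector of batches -> Base.HasShape{1}()
    data_take   = Iterators.take(data_vec, 3)            # wrapped iterator   -> Base.HasLength()
    data_stream = Iterators.filter(x -> true, data_vec)  # length unknown     -> Base.SizeUnknown()

    for data in (data_vec, data_take, data_stream)
        itrsz = Base.IteratorSize(typeof(data))
        # Old check: only HasLength() counted, so data_vec got n = 0 and i / n was Inf.
        n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
        progress = iszero(n) ? nothing : 1 / n            # what @logprogress now receives on step 1
        @show itrsz n progress
    end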
