Skip to content

Commit

Permalink
Fix for progress logging to non-VS Code loggers
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Apr 13, 2022
1 parent f038cff commit f5ecde1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ Multiple callbacks can be passed to `cb` as array.
"""
function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
cb = runall(cb)
n = (Base.IteratorSize(typeof(data)) == Base.HasLength()) ? length(data) : 0
itrsz = Base.IteratorSize(typeof(data))
n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
@withprogress for (i, d) in enumerate(data)
try
gs = gradient(ps) do
Expand All @@ -129,7 +130,7 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
rethrow(ex)
end
end
@logprogress i / n
@logprogress iszero(n) ? nothing : i / n
end
end

Expand Down

0 comments on commit f5ecde1

Please sign in to comment.