Skip to content

Commit 7db2ce5

Browse files
committed
Use ProgressLogging instead of Juno
1 parent 8d3b8d3 commit 7db2ce5

File tree

4 files changed

+6
-12
lines changed

4 files changed

+6
-12
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
1111
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1212
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1313
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
14-
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
1514
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1615
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1716
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1817
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1918
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2019
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
20+
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2323
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
@@ -36,7 +36,6 @@ CUDA = "3"
3636
CodecZlib = "0.7"
3737
Colors = "0.12"
3838
Functors = "0.2.1"
39-
Juno = "0.8"
4039
MacroTools = "0.5"
4140
NNlib = "0.8"
4241
NNlibCUDA = "0.2"

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module Flux
44

55
using Base: tail
66
using Statistics, Random, LinearAlgebra
7-
using Zygote, MacroTools, Juno, Reexport
7+
using Zygote, MacroTools, ProgressLogging, Reexport
88
using MacroTools: @forward
99
@reexport using NNlib
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd

src/data/tree.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,6 @@ function Base.show(io::IO, t::Tree)
2121
print_tree(io, t)
2222
end
2323

24-
using Juno
25-
26-
@render Juno.Inline t::Tree begin
27-
render(t) = Juno.Tree(t.value, render.(t.children))
28-
Juno.Tree(typeof(t), [render(t)])
29-
end
30-
3124
Base.getindex(t::Tree, i::Integer) = t.children[i]
3225
Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...]
3326

src/optimise/train.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Juno
1+
using ProgressLogging: @withprogress, @logprogress
22
import Zygote: Params, gradient
33

44
"""
@@ -104,7 +104,8 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
104104
function train!(loss, ps, data, opt; cb = () -> ())
105105
ps = Params(ps)
106106
cb = runall(cb)
107-
@progress for d in data
107+
n = (Base.IteratorSize(typeof(data)) == Base.HasLength()) ? length(data) : 0
108+
@withprogress for (i, d) in enumerate(data)
108109
try
109110
gs = gradient(ps) do
110111
loss(batchmemaybe(d)...)
@@ -120,6 +121,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
120121
rethrow(ex)
121122
end
122123
end
124+
@logprogress i / n
123125
end
124126
end
125127

0 commit comments

Comments
 (0)