I am trying to optimize a super-resolution GAN, and I see very long compile times.
This issue is discussed here:
https://discourse.julialang.org/t/significant-compile-time-latency-in-flux-with-a-gan/68518
@ToucheSir was able to reduce this issue to the code below (thanks a lot!).
I am using Julia 1.6.3 on Linux and these packages:
(@v1.6) pkg> st Zygote ZygoteRules Flux
Status `~/.julia/environments/v1.6/Project.toml`
[587475ba] Flux v0.12.8 `https://github.com/FluxML/Flux.jl.git#master`
[e88e6eb3] Zygote v0.6.30
[700de1a5] ZygoteRules v0.2.2
This issue appears to be sensitive to both the optimization level and the Julia version used.
With Julia 1.6.3 (default optimization level):
2.129568 seconds (6.12 M allocations: 352.760 MiB, 5.21% gc time, 99.96% compilation time)
0.000323 seconds (1.80 k allocations: 138.938 KiB)
2740.221683 seconds (62.97 M allocations: 3.596 GiB, 0.05% gc time, 0.44% compilation time)
0.004320 seconds (12.04 k allocations: 2.761 MiB)
Note the timing of the first call to loss_grad(lr_images, ps) (third line): about 2740 seconds.
With Julia 1.5.3 I get the following timings:
1.969913 seconds (8.92 M allocations: 458.885 MiB, 6.31% gc time)
0.000265 seconds (1.30 k allocations: 102.344 KiB)
25.426727 seconds (57.88 M allocations: 2.927 GiB, 3.93% gc time)
0.005641 seconds (14.71 k allocations: 1.047 MiB)
With Julia 1.6.3 started as "julia -O1":
1.389465 seconds (6.12 M allocations: 352.764 MiB, 25.61% gc time, 99.94% compilation time)
0.000403 seconds (2.10 k allocations: 143.625 KiB)
28.301099 seconds (62.99 M allocations: 3.596 GiB, 3.96% gc time, 26.40% compilation time)
0.004889 seconds (12.94 k allocations: 2.775 MiB)
using Flux

channels = 4

function resblock(channels)
    return SkipConnection(Chain(
        Conv((3, 3), channels => channels, pad=1),
        Conv((3, 3), channels => channels, pad=1),
    ), +)
end

model = Chain(
    SkipConnection(
        Chain(
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
        ),
    +),
    AdaptiveMeanPool((1, 1))
)

display(model)
println()
@show typeof(model)

loss(x) = sum(model(x))

lr_images = randn(Float32, 2, 2, channels, 1)
@time loss(lr_images)
@time loss(lr_images)

loss_grad(x, ps) = gradient(() -> loss(x), ps)

ps = Flux.params(model)
@time loss_grad(lr_images, ps)
@time loss_grad(lr_images, ps)
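
For anyone comparing numbers across setups, here is a minimal sketch (not part of the original reproducer) of the same first-call versus second-call measurement, written without Flux so it runs on a bare Julia install. It also prints the Julia version and optimization level, the two factors the timings above seem to depend on. The names `square_sum`, `x`, `t_first` and `t_second` are made up for illustration, and `Base.JLOptions()` is an unexported internal, so its fields may differ between Julia versions.

```julia
# Illustrative sketch only: `square_sum` is a hypothetical stand-in for the
# `loss` / `loss_grad` functions in the reproducer above.
square_sum(x) = sum(abs2, x)

x = randn(Float32, 16)

# Record the environment the timings were taken in.
println("Julia version:      ", VERSION)
# Base.JLOptions() is internal (unexported); opt_level reflects the -O flag.
println("Optimization level: ", Base.JLOptions().opt_level)

# Timed at top level, like the @time calls above: the first call has to
# compile `square_sum` for Vector{Float32}, the second reuses that code.
t_first  = @elapsed square_sum(x)
t_second = @elapsed square_sum(x)

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

Running this once with plain `julia` and once with `julia -O1` gives a quick check that the optimization flag is actually being picked up before running the full reproducer.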
This issue might be related: #1119