
Significant compile time latency in Flux with default optimization #1126

@Alexander-Barth

I am trying to train (optimize the parameters of) a super-resolution GAN, and I see a very long compile time.
This issue is also discussed on Discourse:
https://discourse.julialang.org/t/significant-compile-time-latency-in-flux-with-a-gan/68518

@ToucheSir was able to reduce this issue to the minimal example below (thanks a lot!).

I am using Julia 1.6.3 on Linux and these packages:

(@v1.6) pkg> st Zygote ZygoteRules Flux
      Status `~/.julia/environments/v1.6/Project.toml`
  [587475ba] Flux v0.12.8 `https://github.com/FluxML/Flux.jl.git#master`
  [e88e6eb3] Zygote v0.6.30
  [700de1a5] ZygoteRules v0.2.2

The issue appears to be sensitive to both the optimization level and the Julia version.
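When comparing timings like the ones below across sessions, it helps to record the Julia version and the optimization level each run used. A minimal sketch, assuming the internal `Base.JLOptions()` struct and its `opt_level` field (not part of the public API, so this may change between Julia versions):

```julia
# Report the Julia version and the -O level of the running session.
# Base.JLOptions() is internal and may change between Julia versions.
opt = Base.JLOptions().opt_level   # 2 by default, 1 when started with `julia -O1`
println("Julia ", VERSION, ", optimization level -O", opt)
```

Pasting this line next to each set of timings makes it unambiguous which configuration produced them.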

With Julia 1.6.3 (default optimization level, -O2):

 2.129568 seconds (6.12 M allocations: 352.760 MiB, 5.21% gc time, 99.96% compilation time)
  0.000323 seconds (1.80 k allocations: 138.938 KiB)
2740.221683 seconds (62.97 M allocations: 3.596 GiB, 0.05% gc time, 0.44% compilation time)
  0.004320 seconds (12.04 k allocations: 2.761 MiB)

Note the timing of the first call to loss_grad(lr_images, ps): 2740 seconds (about 46 minutes), even though @time attributes only 0.44% of it to compilation.

With Julia 1.5.3 I get the following timings:

  1.969913 seconds (8.92 M allocations: 458.885 MiB, 6.31% gc time)
  0.000265 seconds (1.30 k allocations: 102.344 KiB)
 25.426727 seconds (57.88 M allocations: 2.927 GiB, 3.93% gc time)
  0.005641 seconds (14.71 k allocations: 1.047 MiB)

With Julia 1.6.3 started as "julia -O1":

 1.389465 seconds (6.12 M allocations: 352.764 MiB, 25.61% gc time, 99.94% compilation time)
  0.000403 seconds (2.10 k allocations: 143.625 KiB)
 28.301099 seconds (62.99 M allocations: 3.596 GiB, 3.96% gc time, 26.40% compilation time)
  0.004889 seconds (12.94 k allocations: 2.775 MiB)

Code to reproduce:

using Flux

channels = 4

# Residual block: two 3x3 convolutions (pad=1 keeps the spatial size), output = f(x) + x
function resblock(channels)
    return SkipConnection(Chain(
        Conv((3, 3), channels => channels, pad=1),
        Conv((3, 3), channels => channels, pad=1),
    ), +)
end

# 15 residual blocks inside an outer skip connection, followed by global average pooling
model = Chain(
    SkipConnection(
        Chain(
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),

            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),

            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
        ),
    +),
    AdaptiveMeanPool((1, 1))
)

display(model)
println()
@show typeof(model)  # the model type is deeply nested

loss(x) = sum(model(x))

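# A single 2x2 input with `channels` channels and batch size 1 (WHCN layout)
lr_images = randn(Float32, 2, 2, channels, 1)
@time loss(lr_images)  # first call: dominated by compilation
@time loss(lr_images)  # second call: run time only

# Gradient with respect to the implicit parameters `ps`
loss_grad(x, ps) = gradient(() -> loss(x), ps)

ps = Flux.params(model)
@time loss_grad(lr_images, ps)  # first call: this is the slow one
@time loss_grad(lr_images, ps)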

This issue might be related to #1119.