Skip to content

Commit 81eea84

Browse files
authored
Use destructure from Optimisers.jl (#1901)
* rm destructure * try to fix Downstream.yml by copying NNlib * Optimisers 0.2.1 * rm trainable fallback defn * more tests * test no longer broken * enlarge downstream for now * revert steps for downstream testing
1 parent ed78e8a commit 81eea84

File tree

5 files changed

+95
-62
lines changed

5 files changed

+95
-62
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ MLUtils = "0.2"
3434
MacroTools = "0.5"
3535
NNlib = "0.8.2"
3636
NNlibCUDA = "0.2"
37-
Optimisers = "0.2"
37+
Optimisers = "0.2.1"
3838
ProgressLogging = "0.1"
3939
Reexport = "0.2, 1.0"
4040
SpecialFunctions = "1.8.2, 2.1.2"

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MacroTools: @forward
77

88
@reexport using NNlib
99
using MLUtils
10-
import Optimisers: trainable # before v0.13, Flux owned this function
10+
import Optimisers: trainable, destructure # before v0.13, Flux owned these functions
1111

1212
using Zygote, ChainRulesCore
1313
using Zygote: Params, @adjoint, gradient, pullback, @nograd

src/functor.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ using Zygote: IdSet
44
import Functors: Functors, @functor, functor, fmap, isleaf
55
using SparseArrays: AbstractSparseArray
66

7-
trainable(m) = functor(m)[1]
8-
97
"""
108
testmode!(m, mode = true)
119

src/utils.jl

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -475,59 +475,6 @@ function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
475475
bias
476476
end
477477

478-
# Flattening models to weight vectors, and back
479-
480-
function _restructure(m, xs)
481-
i = 0
482-
m̄ = fmap(m) do x
483-
x isa AbstractArray || return x
484-
x = reshape(xs[i.+(1:length(x))], size(x))
485-
i += length(x)
486-
return x
487-
end
488-
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
489-
return m̄
490-
end
491-
492-
@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule
493-
m̄, numel = _restructure(m, xs), length(xs)
494-
function _restructure_pullback(dm)
495-
xs′ = destructure(dm)[1]
496-
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
497-
return (nothing, xs′)
498-
end
499-
return m̄, _restructure_pullback
500-
end
501-
502-
"""
503-
destructure(m)
504-
505-
Flatten a model's parameters into a single weight vector.
506-
507-
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
508-
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
509-
510-
julia> θ, re = destructure(m);
511-
512-
julia> θ
513-
67-element Vector{Float32}:
514-
-0.1407104
515-
...
516-
517-
The second return value `re` allows you to reconstruct the original network after making
518-
modifications to the weight vector (for example, with a hypernetwork).
519-
520-
julia> re(θ .* 2)
521-
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
522-
"""
523-
function destructure(m)
524-
xs = Zygote.Buffer([])
525-
fmap(m) do x
526-
x isa AbstractArray && push!(xs, x)
527-
return x
528-
end
529-
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
530-
end
531478

532479
# Other
533480

test/utils.jl

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,7 @@ end
390390
∇m = gradient(m -> sum(m(x)), m)[1]
391391
p, re = destructure(m)
392392
∇p = gradient(θ -> sum(re(θ)(x)), p)[1]
393-
if VERSION >= v"1.7"
394-
@test_broken ∇p ≈ destructure(∇m)[1]
395-
else
396-
@test ∇p ≈ destructure(∇m)[1]
397-
end
393+
@test ∇p ≈ destructure(∇m)[1]
398394
end
399395
end
400396
end
@@ -538,3 +534,95 @@ end
538534
@test n_iter == 3
539535
end
540536
end
537+
538+
@testset "Various destructure bugs" begin
539+
540+
@testset "issue 1601" begin
541+
struct TwoDenses
542+
dense::Dense
543+
dense2::Dense
544+
end
545+
Flux.@functor TwoDenses
546+
547+
function (m::TwoDenses)(x)
548+
out = m.dense(x)
549+
end
550+
551+
model = TwoDenses(
552+
Dense(3,1),
553+
Dense(3,2)
554+
)
555+
p, re = Flux.destructure(model)
556+
557+
x = [1., 2., 3.]
558+
y, back = Flux.Zygote.pullback((x, p) -> re(p)(x), x, p)
559+
560+
dy = [4.]
561+
dx, dp = back(dy)
562+
@test length(p) == length(dp)
563+
end
564+
565+
@testset "issue 1727" begin
566+
p, re = Flux.destructure(BatchNorm(3)) # 6 parameters, plus 6 non-trainable
567+
@test length(p) == 6
568+
569+
x = rand(Float32, 3, 4)
570+
y, back = Flux.pullback(x, p) do x, p
571+
vec(re(p)(x))
572+
end
573+
@test_nowarn back(y)
574+
b = back(y)
575+
576+
@test size(b[1]) == size(x)
577+
@test size(b[2]) == size(p)
578+
end
579+
580+
@testset "issue 1767" begin
581+
struct Model{A}
582+
a::A
583+
b::A
584+
end
585+
Flux.@functor Model
586+
(m::Model)(x) = m.a(x) .+ m.b(x)
587+
588+
d = Dense(1, 1)
589+
x = rand(Float32, 1, 1)
590+
591+
# Sharing the parameters
592+
model = Model(d, d)
593+
594+
# Works
595+
g1 = Flux.gradient(() -> sum(model(x)), Flux.params(model))
596+
597+
p, re = Flux.destructure(model)
598+
# Fails
599+
g2 = Flux.gradient(p -> sum(re(p)(x)), p)
600+
601+
@test g2[1] ≈ vcat(g1[d.weight], g1[d.bias])
602+
end
603+
604+
@testset "issue 1826" begin
605+
struct Split{T} # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
606+
paths::T
607+
end
608+
Split(paths...) = Split(paths)
609+
Flux.@functor Split
610+
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
611+
612+
n_input, n_batch, n_shared = 5, 13, 11
613+
n_outputs = [3, 7]
614+
615+
data = rand(Float32, n_input, n_batch)
616+
model = Chain(
617+
Dense(n_input, n_shared),
618+
Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2]))
619+
)
620+
621+
pvec, re = Flux.destructure(model)
622+
loss(x, idx, pv) = sum(abs2, re(pv)(x)[idx]) # loss wrt `idx`th output term
623+
624+
g = Flux.Zygote.ForwardDiff.gradient(pv -> loss(data, 1, pv), pvec)
625+
@test g ≈ Flux.Zygote.gradient(pv -> loss(data, 1, pv), pvec)[1]
626+
end
627+
628+
end

0 commit comments

Comments
 (0)