
Commit 0b62a91

use Functors 0.3 in Flux (#2007)

* use Functors 0.3
* test which needs Functors 0.3
* fix show
* rm useless test

1 parent 9aa123c commit 0b62a91

File tree: 6 files changed (+28 −9 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5, 6"
 CUDA = "3"
 ChainRulesCore = "1.12"
-Functors = "0.2.8"
+Functors = "0.3"
 MLUtils = "0.2"
 MacroTools = "0.5"
 NNlib = "0.8.9"
```
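
Since this is the only compat change, picking it up just means resolving an environment where Flux and Functors 0.3 coexist. A quick way to check what a given environment resolved to (generic Pkg tooling, not part of this commit):

```julia
using Pkg
Pkg.status("Functors")  # after this commit, Flux requires Functors 0.3.x
```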

src/functor.jl

Lines changed: 9 additions & 6 deletions
```diff
@@ -36,13 +36,16 @@ Possible values include:
 """
 trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
 
-params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
-
 function params!(p::Params, x, seen = IdSet())
-  x in seen && return
-  push!(seen, x)
-  for child in trainable(x)
-    params!(p, child, seen)
+  if x isa AbstractArray{<:Number} && Functors.isleaf(x)
+    return push!(p, x)
+  elseif x in seen
+    nothing
+  else
+    push!(seen, x)
+    for child in trainable(x)
+      params!(p, child, seen)
+    end
   end
 end
```
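
A minimal sketch of what the reordered logic buys (assuming Flux with Functors 0.3; `W` is an illustrative array, not from the commit). A plain numeric array is a Functors leaf and is pushed directly; an `Adjoint` wrapper is no longer a leaf under Functors 0.3, so recursion reaches its parent array, and the `seen` IdSet collects the shared parent exactly once:

```julia
using Flux, Functors

W = [1.0 2.0; 3.0 4.0]

Functors.isleaf(W)   # true: plain numeric arrays are leaves, pushed directly
Functors.isleaf(W')  # false: Functors 0.3 recurses into the Adjoint wrapper

# Both fields reach the same parent array, so it is collected once:
ps = Flux.params((m = W, a = W'))
length(ps)           # 1
```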

src/layers/show.jl

Lines changed: 2 additions & 1 deletion
```diff
@@ -47,7 +47,8 @@ end
 _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
 _show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
 _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
-_show_leaflike(::Diagonal) = true # appears inside LayerNorm
+_show_leaflike(::Scale) = true # appears inside LayerNorm
+_show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays
 
 _show_children(x) = trainable(x) # except for layers which hide their Tuple:
 _show_children(c::Chain) = c.layers
```
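
Two separate fixes here: Flux's `Diagonal` layer was renamed to `Scale` (hence the method swap), and because Functors 0.3 now treats wrapped arrays such as `Adjoint` as non-leaf, the new `AbstractArray{<:Number}` method keeps printing from descending into them. A small demonstration using `_show_leaflike` (internal API, shown only for illustration):

```julia
using Flux, Functors

W = [1 2; 3 4]
Functors.isleaf(W')      # false: Functors 0.3 sees through the Adjoint
Flux._show_leaflike(W')  # true: printing still treats any numeric array
                         # as a leaf, via the method added in this commit
```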

test/cuda/cuda.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -112,7 +112,7 @@ end
 
   # Even more trivial: no movement
   @test gradient(x -> sum(abs, cpu(x)), a)[1] isa Matrix
-  @test gradient(x -> sum(abs, cpu(x)), a')[1] isa LinearAlgebra.Adjoint
+  @test gradient(x -> sum(abs, cpu(x)), a')[1] isa Matrix
   @test gradient(x -> sum(cpu(x)), a)[1] isa typeof(gradient(sum, a)[1]) # FillArray
   @test gradient(x -> sum(abs, gpu(x)), ca)[1] isa CuArray
   @test_skip gradient(x -> sum(abs, gpu(x)), ca')[1] isa CuArray # KernelError: passing and using non-bitstype argument
```
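
The expected type changes because `cpu` is built on Functors' `fmap`, which now recurses through the `Adjoint` wrapper, so the pullback returns a plain `Matrix`. The adjusted line can be reproduced without a GPU (a sketch assuming `a` is an ordinary CPU matrix, as in the surrounding test):

```julia
using Flux  # gradient is re-exported from Zygote

a = rand(3, 3)
g = gradient(x -> sum(abs, cpu(x)), a')[1]
g isa Matrix  # true with Functors 0.3; previously a LinearAlgebra.Adjoint
```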

test/layers/show.jl

Lines changed: 4 additions & 0 deletions
```diff
@@ -68,4 +68,8 @@ end
   @test Meta.isexpr(Meta.parse(toplevel_longchain), :call) # comments are ignored
   @test Meta.parse(toplevel_longchain).args[1] == :Chain
 
+  # Functors@0.3 marks transposed matrices non-leaf, shouldn't affect printing:
+  adjoint_chain = repr("text/plain", Chain([Dense([1 2; 3 4]')]))
+  @test occursin("Dense(2 => 2)", adjoint_chain)
+  @test occursin("Chain([", adjoint_chain)
 end
```

test/utils.jl

Lines changed: 11 additions & 0 deletions
```diff
@@ -665,6 +665,17 @@ end
 end
 end
 
+@testset "Shared parameters" begin
+  mat = [1 2; 3 4.0]
+  simple = ((nothing, mat, (3, mat, 4)))
+  @test length(Flux.params(simple)) == 1
+
+  oneadj = (nt = (m = mat, a = mat'))
+  @test length(Flux.params(oneadj)) == 1 # needs Functors@0.3
+
+  @test Flux.destructure(simple)[1] == Flux.destructure(oneadj)[1] == [1, 3, 2, 4]
+end
+
 @testset "Various destructure bugs" begin
 
   @testset "issue 1601" begin
```
