Skip to content

Commit 0e5356f

Browse files
authored
Functor Transpose et. al. (#33)
* simplest isbits usecache * apply the cache only to leaf nodes * functor Transpose, Adjoint, PermutedDimsArray * rm version check, etc * rm usecache, for another PR * tidy, tests
1 parent 014b7a3 commit 0e5356f

File tree

8 files changed

+130
-28
lines changed

8 files changed

+130
-28
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.0'
2120
- '1.6' # Replace this with the minimum Julia version that your package supports.
2221
- '1' # automatically expands to the latest stable 1.x release of Julia
2322
- 'nightly'

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
name = "Functors"
22
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
33
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
4-
version = "0.2.8"
4+
version = "0.3.0"
5+
6+
[deps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
58

69
[compat]
710
Documenter = "0.27"
8-
julia = "1"
11+
julia = "1.6"
912

1013
[extras]
1114
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

src/Functors.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ usually using the macro [`@functor`](@ref).
2323
"""
2424
functor
2525

26-
@static if VERSION >= v"1.5" # var"@functor" doesn't work on 1.0, temporarily disable
2726
@doc """
2827
@functor T
2928
@functor T (x,)
@@ -66,7 +65,6 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560)
6665
```
6766
"""
6867
var"@functor"
69-
end # VERSION
7068

7169
"""
7270
Functors.isleaf(x)

src/base.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,38 @@
1+
12
@functor Base.RefValue
3+
4+
functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner)
5+
6+
###
7+
### Array wrappers
8+
###
9+
10+
using LinearAlgebra
11+
# The reason for these is to let W and W' be seen as tied weights in Flux models.
12+
# Can't treat ReshapedArray very well, as its type doesn't include enough details for reconstruction.
13+
14+
functor(::Type{<:Adjoint}, x) = (parent = _adjoint(x),), y -> adjoint(only(y))
15+
16+
_adjoint(x) = adjoint(x) # _adjoint is the inverse, and also understands more types:
17+
_adjoint(x::NamedTuple{(:parent,)}) = x.parent # "structural" gradient, and lazy broadcast used by Optimisers:
18+
_adjoint(bc::Broadcast.Broadcasted{S}) where S = Broadcast.Broadcasted{S}(_conjugate(bc.f, adjoint), _adjoint.(bc.args))
19+
20+
functor(::Type{<:Transpose}, x) = (parent = _transpose(x),), y -> transpose(only(y))
21+
22+
_transpose(x) = transpose(x)
23+
_transpose(x::NamedTuple{(:parent,)}) = x.parent
24+
_transpose(bc::Broadcast.Broadcasted{S}) where S = Broadcast.Broadcasted{S}(_conjugate(bc.f, transpose), _transpose.(bc.args))
25+
26+
_conjugate(f::F, ::typeof(identity)) where F = f
27+
_conjugate(f::F, op::Union{typeof(transpose), typeof(adjoint)}) where F = (xs...,) -> op(f(op.(xs)...))
28+
29+
function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x) where {T,N,perm,iperm}
30+
(parent = _PermutedDimsArray(x, iperm),), y -> PermutedDimsArray(only(y), perm)
31+
end
32+
function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x::PermutedDimsArray{Tx,N,perm,iperm}) where {T,Tx,N,perm,iperm}
33+
(parent = parent(x),), y -> PermutedDimsArray(only(y), perm) # most common case, avoid wrapping wrice.
34+
end
35+
36+
_PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm)
37+
_PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent
38+
_PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm)

src/functor.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@ functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty
88
functor(::Type{<:AbstractArray}, x) = x, y -> y
99
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
1010

11-
@static if VERSION >= v"1.6"
12-
functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner)
13-
end
14-
1511
function makefunctor(m::Module, T, fs = fieldnames(T))
1612
yᵢ = 0
1713
escargs = map(fieldnames(T)) do f

test/base.jl

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,86 @@
1-
@testset "Base" begin
2-
@testset "RefValue" begin
3-
x = Ref(1)
4-
p, re = Functors.functor(x)
5-
@test p == (x = 1,)
6-
@test re(p) isa Base.RefValue{Int}
7-
end
1+
2+
@testset "RefValue" begin
3+
@test fmap(sqrt, Ref(16))[] == 4.0
4+
@test fmap(sqrt, Ref(16)) isa Ref
5+
@test fmapstructure(sqrt, Ref(16)) === (x = 4.0,)
6+
7+
x = Ref(13)
8+
p, re = Functors.functor(x)
9+
@test p == (x = 13,)
10+
@test re(p) isa Base.RefValue{Int}
11+
12+
x2 = (a = x, b = [7, x, nothing], c = (7, nothing, Ref(13)))
13+
y2 = fmap(identity, x2)
14+
@test x2.a !== y2.a # it's a new Ref
15+
@test y2.a === y2.b[2] # relation is maintained
16+
@test y2.a !== y2.c[3] # no new relation created
17+
18+
x3 = Ref([3.14])
19+
f3 = [Foo(x3, x), x3, x]
20+
@test f3[1].x === f3[2]
21+
y3 = fmapstructure(identity, f3) # replaces mutable with immutable
22+
@test y3[1].x === y3[2]
23+
@test y3[1].x.x === y3[2].x
24+
z3 = fmapstructure(identity, y3)
25+
@test z3[1].x === z3[2]
26+
@test z3[1].x.x === z3[2].x
27+
end
28+
29+
@testset "ComposedFunction" begin
30+
f1 = Foo(1.1, 2.2)
31+
f2 = Bar(3.3)
32+
@test Functors.functor(f1 f2)[1] == (outer = f1, inner = f2)
33+
@test Functors.functor(f1 f2)[2]((outer = f1, inner = f2)) == f1 f2
34+
@test fmap(x -> x + 10, f1 f2) == Foo(11.1, 12.2) Bar(13.3)
35+
end
36+
37+
@testset "LinearAlgebra containers" begin
38+
@test fmapstructure(identity, [1,2,3]') == (parent = [1, 2, 3],)
39+
@test fmapstructure(identity, transpose([1,2,3])) == (parent = [1, 2, 3],)
40+
41+
CNT = Ref(0)
42+
fv(x::Vector) = (CNT[]+=1; 10v)
43+
44+
v = [1,2,3]
45+
nt = fmap(fv, (a=v, b=v', c=transpose(v), d=[1,2,3]'))
46+
47+
@test nt.a === adjoint(nt.b) # does not break tie
48+
@test nt.a === transpose(nt.c)
49+
50+
@test CNT[] == 2
51+
@test nt.a == adjoint(nt.d) # does not create a new tie
52+
@test nt.a !== adjoint(nt.d)
53+
54+
@test nt.b isa Adjoint
55+
@test nt.c isa Transpose
56+
57+
x = [1,2,3]'
58+
xs = fmapstructure(identity, x) # check it digests this, e.g. structural gradient representation
59+
@test Functors.functor(typeof(x), xs) == Functors.functor(x) # (no real need for [2] types to match)
60+
61+
x = transpose([1 2; 3 4])
62+
yt = transpose([5 6; 7 8])
63+
ym = Matrix(yt) # check it digests this, e.g. simplest Matrix gradient
64+
@test Functors.functor(typeof(x), yt)[1].parent == Functors.functor(typeof(x), ym)[1].parent
65+
66+
ybc = Broadcast.broadcasted(+, ym, 9) # check it digests this, as Optimisers.jl makes these
67+
collect(ybc) isa Vector
68+
zbc = Functors.functor(typeof(x), ybc)[1].parent
69+
@test zbc .+ 0 == Functors.functor(typeof(x), ym .+ 9)[1].parent
70+
71+
# Similar checks for Adjoint.
72+
x = adjoint([1 2im 3; 4im 5 6im])
73+
yt = adjoint([7im 8 9; 0 im 2])
74+
ym = Matrix(yt)
75+
@test Functors.functor(typeof(x), yt)[1].parent == Functors.functor(typeof(x), ym)[1].parent
76+
77+
ybc = Broadcast.broadcasted(+, ym, [11im, 12, im])
78+
collect(ybc) isa Vector
79+
zbc = Functors.functor(typeof(x), ybc)[1].parent
80+
@test zbc .+ 0 == Functors.functor(typeof(x), ym .+ [11im, 12, im])[1].parent
81+
end
82+
83+
@testset "PermutedDimsArray" begin
84+
@test fmapstructure(identity, PermutedDimsArray([1 2; 3 4], (2,1))) == (parent = [1 2; 3 4],)
85+
@test fmap(exp, PermutedDimsArray([1 2; 3 4], (2,1))) isa PermutedDimsArray{Float64}
886
end

test/basics.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,15 @@ using Functors: functor
44
struct Foo; x; y; end
55
@functor Foo
66

7-
struct Bar; x; end
7+
struct Bar{T}; x::T; end
88
@functor Bar
99

1010
struct OneChild3; x; y; z; end
1111
@functor OneChild3 (y,)
1212

1313
struct NoChildren2; x; y; end
1414

15-
@static if VERSION >= v"1.6"
16-
@testset "ComposedFunction" begin
17-
f1 = Foo(1.1, 2.2)
18-
f2 = Bar(3.3)
19-
@test Functors.functor(f1 f2)[1] == (outer = f1, inner = f2)
20-
@test Functors.functor(f1 f2)[2]((outer = f1, inner = f2)) == f1 f2
21-
@test fmap(x -> x + 10, f1 f2) == Foo(11.1, 12.2) Bar(13.3)
22-
end
23-
end
15+
struct NoChild{T}; x::T; end
2416

2517
###
2618
### Basic functionality
@@ -187,7 +179,6 @@ end
187179
end
188180
end
189181

190-
@static if VERSION >= v"1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope
191182
@testset "old test update.jl" begin
192183
struct M{F,T,S}
193184
σ::F
@@ -211,7 +202,6 @@ end
211202
@test.W fill(0.8f0, size(m.W))
212203
@test.b fill(-0.2f0, size(m.b))
213204
end
214-
end # VERSION
215205

216206
###
217207
### FlexibleFunctors.jl

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Functors, Test
22
using Zygote
3+
using LinearAlgebra
34

45
@testset "Functors.jl" begin
56

0 commit comments

Comments
 (0)