Skip to content

Commit b5b872b

Browse files
Merge branch 'master' into dg/grad
2 parents 101f643 + d58d273 commit b5b872b

File tree

9 files changed

+145
-47
lines changed

9 files changed

+145
-47
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.DS_Store
22
/dev/
3+
Manifest.toml

Manifest.toml

Lines changed: 0 additions & 21 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
name = "Functors"
22
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
33
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
4-
version = "0.2.3"
5-
6-
[deps]
7-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
4+
version = "0.2.7"
85

96
[compat]
10-
MacroTools = "0.5"
117
julia = "1"
128

139
[extras]

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Functors.jl provides tools to express a powerful design pattern for dealing with
1717

1818
Functors.jl provides `fmap` to make those things easy, acting as a 'map over parameters':
1919

20-
```julia-repl
20+
```julia
2121
julia> using Functors
2222

2323
julia> struct Foo
@@ -36,7 +36,7 @@ Foo(1.0, [1.0, 2.0, 3.0])
3636

3737
It works also with deeply-nested models:
3838

39-
```julia-repl
39+
```julia
4040
julia> struct Bar
4141
x
4242
end
@@ -52,7 +52,7 @@ Bar(Foo(1.0, [1.0, 2.0, 3.0]))
5252

5353
The workhorse of `fmap` is actually a lower level function, `functor`:
5454

55-
```julia-repl
55+
```julia
5656
julia> xs, re = functor(Foo(1, [1, 2, 3]))
5757
((x = 1, y = [1, 2, 3]), var"#21#22"())
5858

@@ -64,7 +64,7 @@ Foo(1.0, [1.0, 2.0, 3.0])
6464

6565
To include only certain fields, pass a tuple of field names to `@functor`:
6666

67-
```julia-repl
67+
```julia
6868
julia> struct Baz
6969
x
7070
y
@@ -87,7 +87,7 @@ For a discussion regarding the need for a `cache` in the implementation of `fmap
8787

8888
Use `exclude` for more fine-grained control over whether `fmap` descends into a particular value (the default is `exclude = Functors.isleaf`):
8989

90-
```julia-repl
90+
```julia
9191
julia> using CUDA
9292

9393
julia> x = ['a', 'b', 'c'];

src/Functors.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
module Functors
22

3-
using MacroTools
4-
5-
export @functor, fmap, fmapstructure, fcollect
3+
export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect
64

75
include("functor.jl")
6+
include("base.jl")
87

98
end # module

src/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@functor Base.RefValue

src/functor.jl

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ functor(::Type{<:NamedTuple}, x) = x, y -> y
77
functor(::Type{<:AbstractArray}, x) = x, y -> y
88
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x
99

10+
@static if VERSION >= v"1.6"
11+
functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner)
12+
end
13+
1014
function makefunctor(m::Module, T, fs = fieldnames(T))
1115
yᵢ = 0
1216
escargs = map(fieldnames(T)) do f
@@ -20,15 +24,42 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
2024
end
2125

2226
function functorm(T, fs = nothing)
23-
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
24-
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
27+
fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)")
28+
fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
2529
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
2630
end
2731

2832
macro functor(args...)
2933
functorm(args...)
3034
end
3135

36+
function makeflexiblefunctor(m::Module, T, pfield)
37+
pfield = QuoteNode(pfield)
38+
@eval m begin
39+
function $Functors.functor(::Type{<:$T}, x)
40+
pfields = getproperty(x, $pfield)
41+
function re(y)
42+
all_args = map(fn -> getproperty(fn in pfields ? y : x, fn), fieldnames($T))
43+
return $T(all_args...)
44+
end
45+
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
46+
return func, re
47+
end
48+
49+
end
50+
51+
end
52+
53+
function flexiblefunctorm(T, pfield = :params)
54+
pfield isa Symbol || error("@flexiblefunctor T param_field")
55+
pfield = QuoteNode(pfield)
56+
:(makeflexiblefunctor(@__MODULE__, $(esc(T)), $(esc(pfield))))
57+
end
58+
59+
macro flexiblefunctor(args...)
60+
flexiblefunctorm(args...)
61+
end
62+
3263
"""
3364
isleaf(x)
3465
@@ -137,7 +168,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x))
137168
fcollect(x; exclude = v -> false)
138169
139170
Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
140-
and collecting the results into a flat array.
171+
and collecting the results into a flat array, ordered by a breadth-first
172+
traversal of `x`, respecting the iteration order of `children` calls.
141173
142174
Doesn't recurse inside branches rooted at nodes `v`
143175
for which `exclude(v) == true`.
@@ -180,13 +212,17 @@ julia> fcollect(m, exclude = v -> Functors.isleaf(v))
180212
Bar([1, 2, 3])
181213
```
182214
"""
183-
function fcollect(x; cache = [], exclude = v -> false)
184-
x in cache && return cache
185-
if !exclude(x)
186-
push!(cache, x)
187-
foreach(y -> fcollect(y; cache = cache, exclude = exclude), children(x))
188-
end
189-
return cache
215+
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
216+
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
217+
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
218+
# for the results, to preserve traversal order (important downstream!).
219+
x in cache && return output
220+
if !exclude(x)
221+
push!(cache, x)
222+
push!(output, x)
223+
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
224+
end
225+
return output
190226
end
191227

192228
# Allow gradients and other constructs that match the structure of the functor

test/base.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@testset "Base" begin
2+
@testset "RefValue" begin
3+
x = Ref(1)
4+
p, re = Functors.functor(x)
5+
@test p == (x = 1,)
6+
@test re(p) isa Base.RefValue{Int}
7+
end
8+
end

test/basics.jl

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Functors, Test
2-
31
struct Foo
42
x
53
y
@@ -23,6 +21,16 @@ struct NoChildren
2321
y
2422
end
2523

24+
@static if VERSION >= v"1.6"
25+
@testset "ComposedFunction" begin
26+
f1 = Foo(1.1, 2.2)
27+
f2 = Bar(3.3)
28+
@test Functors.functor(f1 f2)[1] == (outer = f1, inner = f2)
29+
@test Functors.functor(f1 f2)[2]((outer = f1, inner = f2)) == f1 f2
30+
@test fmap(x -> x + 10, f1 f2) == Foo(11.1, 12.2) Bar(13.3)
31+
end
32+
end
33+
2634
@testset "Nested" begin
2735
model = Bar(Foo(1, [1, 2, 3]))
2836

@@ -73,6 +81,76 @@ end
7381
m0 = NoChildren(:a, :b)
7482
m3 = Foo(m2, m0)
7583
m4 = Bar(m3)
76-
println(fcollect(m4))
7784
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
85+
86+
m1 = [1, 2, 3]
87+
m2 = [1, 2, 3]
88+
m3 = Foo(m1, m2)
89+
@test all(fcollect(m3) .=== [m3, m1, m2])
90+
end
91+
92+
struct FFoo
93+
x
94+
y
95+
p
96+
end
97+
@flexiblefunctor FFoo p
98+
99+
struct FBar
100+
x
101+
p
102+
end
103+
@flexiblefunctor FBar p
104+
105+
struct FBaz
106+
x
107+
y
108+
z
109+
p
110+
end
111+
@flexiblefunctor FBaz p
112+
113+
@testset "Flexible Nested" begin
114+
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))
115+
116+
model′ = fmap(float, model)
117+
118+
@test model.x.y == model′.x.y
119+
@test model′.x.y isa Vector{Float64}
120+
end
121+
122+
@testset "Flexible Walk" begin
123+
model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y))
124+
125+
model′ = fmapstructure(identity, model)
126+
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
127+
128+
model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,))
129+
130+
model2′ = fmapstructure(identity, model2)
131+
@test model2′ == (; x=(0, (; x=[1, 2, 3])))
132+
end
133+
134+
@testset "Flexible Property list" begin
135+
model = FBaz(1, 2, 3, (:x, :z))
136+
model′ = fmap(x -> 2x, model)
137+
138+
@test (model′.x, model′.y, model′.z) == (2, 2, 6)
139+
end
140+
141+
@testset "Flexible fcollect" begin
142+
m1 = 1
143+
m2 = [1, 2, 3]
144+
m3 = FFoo(m1, m2, (:y, ))
145+
m4 = FBar(m3, (:x,))
146+
@test all(fcollect(m4) .=== [m4, m3, m2])
147+
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
148+
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])
149+
150+
m0 = NoChildren(:a, :b)
151+
m1 = [1, 2, 3]
152+
m2 = FBar(m1, ())
153+
m3 = FFoo(m2, m0, (:x, :y,))
154+
m4 = FBar(m3, (:x,))
155+
@test all(fcollect(m4) .=== [m4, m3, m2, m0])
78156
end

0 commit comments

Comments
 (0)