Skip to content

Commit 977e357

Browse files
authored
Merge pull request #43 from darsnack/walks
2 parents 981c866 + ffd6ae9 commit 977e357

File tree

8 files changed

+222
-86
lines changed

8 files changed

+222
-86
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ julia = "1.6"
1212

1313
[extras]
1414
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[targets]
19-
test = ["Test", "Documenter", "Zygote"]
20+
test = ["Test", "Documenter", "StaticArrays", "Zygote"]

docs/src/api.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ Functors.children
99
Functors.isleaf
1010
```
1111

12+
```@docs
13+
Functors.AbstractWalk
14+
Functors.DefaultWalk
15+
Functors.StructuralWalk
16+
Functors.ExcludeWalk
17+
Functors.CachedWalk
18+
Functors.CollectWalk
19+
Functors.AnonymousWalk
20+
```
21+
1222
```@docs
1323
Functors.fmapstructure
1424
Functors.fcollect

src/Functors.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module Functors
33
export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect
44

55
include("functor.jl")
6+
include("walks.jl")
7+
include("maps.jl")
68
include("base.jl")
79

810
###
@@ -102,7 +104,8 @@ Equivalent to `functor(x)[1]`.
102104
children
103105

104106
"""
105-
fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk)
107+
fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk()[, prune])
108+
fmap(walk, f, x, ys...)
106109
107110
A structure and type preserving `map`.
108111
@@ -176,12 +179,23 @@ Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))"))
176179
To recurse into custom types without reconstructing them afterwards,
177180
use [`fmapstructure`](@ref).
178181
179-
For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`.
180-
This function walks (maps) over `xs` calling the continuation `f'` to continue traversal.
182+
For advanced customization of the traversal behaviour,
183+
pass a custom `walk` function that subtypes [`Functors.AbstractWalk`](ref).
184+
The form `fmap(walk, f, x, ys...)` can be called for custom walks.
185+
The simpler form `fmap(f, x, ys...; walk = mywalk)` will wrap `mywalk` in
186+
[`ExcludeWalk`](@ref) then [`CachedWalk`](@ref).
181187
182188
```jldoctest withfoo
183-
julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x))
184-
Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
189+
julia> struct MyWalk <: Functors.AbstractWalk end
190+
191+
julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" :
192+
Functors.DefaultWalk()(recurse, x)
193+
194+
julia> fmap(x -> 10x, m; walk = MyWalk())
195+
Foo("hello", (40, 50, "hello"))
196+
197+
julia> fmap(MyWalk(), x -> 10x, m)
198+
Foo("hello", (4, 5, "hello"))
185199
```
186200
187201
The behaviour when the same node appears twice can be altered by giving a value

src/functor.jl

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -34,81 +34,6 @@ isleaf(x) = children(x) === ()
3434

3535
children(x) = functor(x)[1]
3636

37-
function _default_walk(f, x)
38-
func, re = functor(x)
39-
re(map(f, func))
40-
end
41-
42-
usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
43-
usecache(::Nothing, x) = false
44-
45-
@generated function anymutable(x::T) where {T}
46-
ismutabletype(T) && return true
47-
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))]
48-
return Expr(:(||), subs...)
49-
end
50-
51-
struct NoKeyword end
52-
53-
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
54-
if usecache(cache, x) && haskey(cache, x)
55-
return prune isa NoKeyword ? cache[x] : prune
56-
end
57-
ret = if exclude(x)
58-
f(x)
59-
else
60-
walk(x -> fmap(f, x; exclude, walk, cache, prune), x)
61-
end
62-
if usecache(cache, x)
63-
cache[x] = ret
64-
end
65-
ret
66-
end
67-
68-
###
69-
### Extras
70-
###
71-
72-
fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...)
73-
74-
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
75-
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
76-
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
77-
# for the results, to preserve traversal order (important downstream!).
78-
x in cache && return output
79-
if !exclude(x)
80-
push!(cache, x)
81-
push!(output, x)
82-
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
83-
end
84-
return output
85-
end
86-
87-
###
88-
### Vararg forms
89-
###
90-
91-
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
92-
if usecache(cache, x) && haskey(cache, x)
93-
return prune isa NoKeyword ? cache[x] : prune
94-
end
95-
ret = if exclude(x)
96-
f(x, ys...)
97-
else
98-
walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...)
99-
end
100-
if usecache(cache, x)
101-
cache[x] = ret
102-
end
103-
ret
104-
end
105-
106-
function _default_walk(f, x, ys...)
107-
func, re = functor(x)
108-
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
109-
re(map(f, func, yfuncs...))
110-
end
111-
11237
###
11338
### FlexibleFunctors.jl
11439
###

src/maps.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
fmap(walk::AbstractWalk, f, x, ys...) = walk((xs...) -> fmap(walk, f, xs...), x, ys...)
2+
3+
function fmap(f, x, ys...; exclude = isleaf,
4+
walk = DefaultWalk(),
5+
cache = IdDict(),
6+
prune = NoKeyword())
7+
_walk = ExcludeWalk(AnonymousWalk(walk), f, exclude)
8+
if !isnothing(cache)
9+
_walk = CachedWalk(_walk, prune, cache)
10+
end
11+
fmap(_walk, f, x, ys...)
12+
end
13+
14+
fmapstructure(f, x; kwargs...) = fmap(f, x; walk = StructuralWalk(), kwargs...)
15+
16+
fcollect(x; exclude = v -> false) =
17+
fmap(ExcludeWalk(CollectWalk(), _ -> nothing, exclude), _ -> nothing, x)

src/walks.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""
2+
AbstractWalk
3+
4+
Any walk for use with [`fmap`](@ref) should inherit from this type.
5+
A walk subtyping `AbstractWalk` must satisfy the walk function interface:
6+
```julia
7+
struct MyWalk <: AbstractWalk end
8+
9+
function (::MyWalk)(recurse, x, ys...)
10+
# implement this
11+
end
12+
```
13+
The walk function is called on a node `x` in a Functors tree.
14+
It may also be passed associated nodes `ys...` in other Functors trees.
15+
The walk function recurses further into `(x, ys...)` by calling
16+
`recurse` on the child nodes.
17+
The choice of which nodes to recurse and in what order is custom to the walk.
18+
"""
19+
abstract type AbstractWalk end
20+
21+
"""
22+
AnonymousWalk(walk_fn)
23+
24+
Wrap a `walk_fn` so that `AnonymousWalk(walk_fn) isa AbstractWalk`.
25+
This type only exists for backwards compatability and should be directly used.
26+
Attempting to wrap an existing `AbstractWalk` is a no-op (i.e. it is not wrapped).
27+
"""
28+
struct AnonymousWalk{F} <: AbstractWalk
29+
walk::F
30+
31+
function AnonymousWalk(walk::F) where F
32+
Base.depwarn("Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtyle `AbstractWalk`.", :AnonymousWalk)
33+
return new{F}(walk)
34+
end
35+
end
36+
# do not wrap an AbstractWalk
37+
AnonymousWalk(walk::AbstractWalk) = walk
38+
39+
(walk::AnonymousWalk)(recurse, x, ys...) = walk.walk(recurse, x, ys...)
40+
41+
"""
42+
DefaultWalk()
43+
44+
The default walk behavior for Functors.jl.
45+
Walks all the [`Functors.children`](@ref) of trees `(x, ys...)` based on
46+
the structure of `x`.
47+
The resulting mapped child nodes are restructured into the type of `x`.
48+
49+
See [`fmap`](@ref) for more information.
50+
"""
51+
struct DefaultWalk <: AbstractWalk end
52+
53+
function (::DefaultWalk)(recurse, x, ys...)
54+
func, re = functor(x)
55+
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
56+
re(map(recurse, func, yfuncs...))
57+
end
58+
59+
"""
60+
StructuralWalk()
61+
62+
A structural variant of [`Functors.DefaultWalk`](@ref).
63+
The recursion behavior is identical, but the mapped children are not restructured.
64+
65+
See [`fmapstructure`](@ref) for more information.
66+
"""
67+
struct StructuralWalk <: AbstractWalk end
68+
69+
(::StructuralWalk)(recurse, x) = map(recurse, children(x))
70+
71+
"""
72+
ExcludeWalk(walk, fn, exclude)
73+
74+
A walk that recurses nodes `(x, ys...)` according to `walk`,
75+
except when `exclude(x)` is true.
76+
Then, `fn(x, ys...)` is applied instead of recursing further.
77+
78+
Typically wraps an existing `walk` for use with [`fmap`](@ref).
79+
"""
80+
struct ExcludeWalk{T, F, G} <: AbstractWalk
81+
walk::T
82+
fn::F
83+
exclude::G
84+
end
85+
86+
(walk::ExcludeWalk)(recurse, x, ys...) =
87+
walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...)
88+
89+
struct NoKeyword end
90+
91+
usecache(::Union{AbstractDict, AbstractSet}, x) =
92+
isleaf(x) ? anymutable(x) : ismutable(x)
93+
usecache(::Nothing, x) = false
94+
95+
@generated function anymutable(x::T) where {T}
96+
ismutabletype(T) && return true
97+
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))]
98+
return Expr(:(||), subs...)
99+
end
100+
101+
"""
102+
CachedWalk(walk[; prune])
103+
104+
A walk that recurses nodes `(x, ys...)` according to `walk` and storing the
105+
output of the recursion in a cache indexed by `x` (based on object ID).
106+
Whenever the cache already contains `x`, either:
107+
- `prune` is specified, then it is returned, or
108+
- `prune` is unspecified, and the previously cached recursion of `(x, ys...)`
109+
returned.
110+
111+
Typically wraps an existing `walk` for use with [`fmap`](@ref).
112+
"""
113+
struct CachedWalk{T, S} <: AbstractWalk
114+
walk::T
115+
prune::S
116+
cache::IdDict{Any, Any}
117+
end
118+
CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
119+
CachedWalk(walk, prune, cache)
120+
121+
function (walk::CachedWalk)(recurse, x, ys...)
122+
should_cache = usecache(walk.cache, x)
123+
if should_cache && haskey(walk.cache, x)
124+
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
125+
else
126+
ret = walk.walk(recurse, x, ys...)
127+
if should_cache
128+
walk.cache[x] = ret
129+
end
130+
return ret
131+
end
132+
end
133+
134+
"""
135+
CollectWalk()
136+
137+
A walk that recurses into a node `x` via [`Functors.children`](@ref),
138+
storing the recursion history in a cache.
139+
The resulting ordered recursion history is returned.
140+
141+
See [`fcollect`](@ref) for more information.
142+
"""
143+
struct CollectWalk <: AbstractWalk
144+
cache::Base.IdSet{Any}
145+
output::Vector{Any}
146+
end
147+
CollectWalk() = CollectWalk(Base.IdSet(), Any[])
148+
149+
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
150+
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
151+
# for the results, to preserve traversal order (important downstream!).
152+
function (walk::CollectWalk)(recurse, x)
153+
if usecache(walk.cache, x) && (x in walk.cache)
154+
return walk.output
155+
end
156+
# to exclude, we wrap this walk in ExcludeWalk
157+
usecache(walk.cache, x) && push!(walk.cache, x)
158+
push!(walk.output, x)
159+
map(recurse, children(x))
160+
161+
return walk.output
162+
end

test/basics.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ end
8585
@test m5f.x === m5f.y
8686
@test m5f.x !== m5f.z
8787

88-
@testset "usecache" begin
89-
d = IdDict()
90-
88+
@testset "usecache ($d)" for d in [IdDict(), Base.IdSet()]
9189
# Leaf types:
9290
@test usecache(d, [1,2])
9391
@test !usecache(d, 4.0)
@@ -101,9 +99,9 @@ end
10199

102100
@test !usecache(d, (5, [6.0]')) # contains mutable
103101
@test !usecache(d, (x = [1,2,3], y = 4))
104-
102+
105103
usecache(d, OneChild3([1,2], 3, nothing)) # mutable isn't a child, do we care?
106-
104+
107105
# No dictionary:
108106
@test !usecache(nothing, [1,2])
109107
@test !usecache(nothing, 3)
@@ -173,6 +171,14 @@ end
173171
m2 = [1, 2, 3]
174172
m3 = Foo(m1, m2)
175173
@test all(fcollect(m3) .=== [m3, m1, m2])
174+
175+
m1 = [1, 2, 3]
176+
m2 = SVector{length(m1)}(m1)
177+
m2′ = SVector{length(m1)}(m1)
178+
m3 = Foo(m1, m1)
179+
m4 = Foo(m2, m2′)
180+
@test all(fcollect(m3) .=== [m3, m1])
181+
@test all(fcollect(m4) .=== [m4, m2, m2′])
176182
end
177183

178184
###

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Functors, Test
22
using Zygote
33
using LinearAlgebra
4+
using StaticArrays
45

56
@testset "Functors.jl" begin
67

0 commit comments

Comments
 (0)