Skip to content

Commit e051e21

Browse files
committed
Separate walks out from fmap
1 parent 981c866 commit e051e21

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

src/Functors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module Functors
33
export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect
44

55
include("functor.jl")
6+
include("walks.jl")
7+
include("maps.jl")
68
include("base.jl")
79

810
###

src/maps.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
fmap(walk::AbstractWalk, f, x, ys...) = walk((xs...) -> fmap(walk, f, xs...), x, ys...)
2+
3+
function fmap(f, x, ys...; exclude = isleaf,
4+
walk = DefaultWalk(),
5+
cache = IdDict(),
6+
prune = NoKeyword())
7+
_walk = CachedWalk(ExcludeWalk(walk, f, exclude), prune, cache)
8+
fmap(_walk, f, x, ys...)
9+
end
10+
11+
fmapstructure(f, x; kwargs...) = fmap(f, x; walk = StructuralWalk(), kwargs...)
12+
13+
fcollect(x; exclude = v -> false) =
14+
fmap(ExcludeWalk(CollectWalk(), _ -> nothing, exclude), _ -> nothing, x)

src/walks.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
abstract type AbstractWalk end
2+
3+
struct DefaultWalk <: AbstractWalk end
4+
5+
function (::DefaultWalk)(recurse, x, ys...)
6+
func, re = functor(x)
7+
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
8+
re(map(recurse, func, yfuncs...))
9+
end
10+
11+
struct StructuralWalk <: AbstractWalk end
12+
13+
(::StructuralWalk)(recurse, x) = map(recurse, children(x))
14+
15+
struct ExcludeWalk{T, F, G} <: AbstractWalk
16+
walk::T
17+
fn::F
18+
exclude::G
19+
end
20+
21+
(walk::ExcludeWalk)(recurse, x, ys...) =
22+
walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...)
23+
24+
struct NoKeyword end
25+
26+
struct CachedWalk{T, S} <: AbstractWalk
27+
walk::T
28+
prune::S
29+
cache::IdDict{Any, Any}
30+
end
31+
CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
32+
CachedWalk(walk, prune, cache)
33+
34+
function (walk::CachedWalk)(recurse, x, ys...)
35+
if haskey(walk.cache, x)
36+
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
37+
else
38+
walk.cache[x] = walk.walk(recurse, x, ys...)
39+
return walk.cache[x]
40+
end
41+
end
42+
43+
struct CollectWalk <: AbstractWalk
44+
cache::Base.IdSet{Any}
45+
output::Vector{Any}
46+
end
47+
CollectWalk() = CollectWalk(Base.IdSet(), Any[])
48+
49+
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
50+
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
51+
# for the results, to preserve traversal order (important downstream!).
52+
function (walk::CollectWalk)(recurse, x)
53+
x in walk.cache && return walk.output
54+
# to exclude, we wrap this walk in ExcludeWalk
55+
push!(walk.cache, x)
56+
push!(walk.output, x)
57+
map(recurse, children(x))
58+
59+
return walk.output
60+
end

0 commit comments

Comments
 (0)