Skip to content

Commit 451a26e

Browse files
committed
move the functions around
1 parent 6284ad9 commit 451a26e

File tree

1 file changed

+56
-49
lines changed

1 file changed

+56
-49
lines changed

src/functor.jl

Lines changed: 56 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -34,60 +34,14 @@ macro functor(args...)
3434
functorm(args...)
3535
end
3636

37-
function makeflexiblefunctor(m::Module, T, pfield)
38-
pfield = QuoteNode(pfield)
39-
@eval m begin
40-
function $Functors.functor(::Type{<:$T}, x)
41-
pfields = getproperty(x, $pfield)
42-
function re(y)
43-
all_args = map(fn -> getproperty(fn in pfields ? y : x, fn), fieldnames($T))
44-
return $T(all_args...)
45-
end
46-
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
47-
return func, re
48-
end
49-
50-
end
51-
52-
end
53-
54-
function flexiblefunctorm(T, pfield = :params)
55-
pfield isa Symbol || error("@flexiblefunctor T param_field")
56-
pfield = QuoteNode(pfield)
57-
:(makeflexiblefunctor(@__MODULE__, $(esc(T)), $(esc(pfield))))
58-
end
59-
60-
macro flexiblefunctor(args...)
61-
flexiblefunctorm(args...)
62-
end
63-
6437
isleaf(x) = children(x) === ()
6538

6639
children(x) = functor(x)[1]
6740

68-
function functor_tuple(f, x::Tuple, dx::Tuple)
69-
map(x, dx) do x, x̄
70-
_default_walk(f, x, x̄)
71-
end
72-
end
73-
functor_tuple(f, x, dx) = f(x, dx)
74-
functor_tuple(f, x, ::Nothing) = x
75-
76-
# @functor Chain
77-
# Chain -> func = (layers = (Dense,Dense),), gs -> (layers...)
78-
function _default_walk(f, x, dx)
79-
func, re = functor(x)
80-
map(func, dx) do x, x̄
81-
# functor_tuple(f, x, x̄)
82-
f(x, x̄)
83-
end |> re
84-
end
85-
8641
function _default_walk(f, x)
8742
func, re = functor(x)
8843
re(map(f, func))
8944
end
90-
_default_walk(f, ::Nothing, ::Nothing) = nothing
9145

9246
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
9347
haskey(cache, x) && return cache[x]
@@ -97,6 +51,10 @@ function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
9751
return y
9852
end
9953

54+
###
55+
### Extras
56+
###
57+
10058
fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...)
10159

10260
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
@@ -112,10 +70,59 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
11270
return output
11371
end
11472

115-
# Allow gradients and other constructs that match the structure of the functor
116-
# to allow for `map` style computations and return a modified version of the struct.
117-
# This way we can use `fmap` to update the params with their gradients
73+
###
74+
### Vararg forms
75+
###
76+
11877
function fmap(f, x, dx...; cache = IdDict())
11978
haskey(cache, x) && return cache[x]
12079
cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...)
12180
end
81+
82+
function functor_tuple(f, x::Tuple, dx::Tuple)
83+
map(x, dx) do x, x̄
84+
_default_walk(f, x, x̄)
85+
end
86+
end
87+
functor_tuple(f, x, dx) = f(x, dx)
88+
functor_tuple(f, x, ::Nothing) = x
89+
90+
function _default_walk(f, x, dx)
91+
func, re = functor(x)
92+
map(func, dx) do x, x̄
93+
# functor_tuple(f, x, x̄)
94+
f(x, x̄)
95+
end |> re
96+
end
97+
_default_walk(f, ::Nothing, ::Nothing) = nothing
98+
99+
###
100+
### FlexibleFunctors.jl
101+
###
102+
103+
function makeflexiblefunctor(m::Module, T, pfield)
104+
pfield = QuoteNode(pfield)
105+
@eval m begin
106+
function $Functors.functor(::Type{<:$T}, x)
107+
pfields = getproperty(x, $pfield)
108+
function re(y)
109+
all_args = map(fn -> getproperty(fn in pfields ? y : x, fn), fieldnames($T))
110+
return $T(all_args...)
111+
end
112+
func = NamedTuple{pfields}(map(p -> getproperty(x, p), pfields))
113+
return func, re
114+
end
115+
116+
end
117+
118+
end
119+
120+
function flexiblefunctorm(T, pfield = :params)
121+
pfield isa Symbol || error("@flexiblefunctor T param_field")
122+
pfield = QuoteNode(pfield)
123+
:(makeflexiblefunctor(@__MODULE__, $(esc(T)), $(esc(pfield))))
124+
end
125+
126+
macro flexiblefunctor(args...)
127+
flexiblefunctorm(args...)
128+
end

0 commit comments

Comments
 (0)