@@ -34,60 +34,14 @@ macro functor(args...)
34
34
functorm (args... )
35
35
end
36
36
37
- function makeflexiblefunctor (m:: Module , T, pfield)
38
- pfield = QuoteNode (pfield)
39
- @eval m begin
40
- function $Functors. functor (:: Type{<:$T} , x)
41
- pfields = getproperty (x, $ pfield)
42
- function re (y)
43
- all_args = map (fn -> getproperty (fn in pfields ? y : x, fn), fieldnames ($ T))
44
- return $ T (all_args... )
45
- end
46
- func = NamedTuple {pfields} (map (p -> getproperty (x, p), pfields))
47
- return func, re
48
- end
49
-
50
- end
51
-
52
- end
53
-
54
- function flexiblefunctorm (T, pfield = :params )
55
- pfield isa Symbol || error (" @flexiblefunctor T param_field" )
56
- pfield = QuoteNode (pfield)
57
- :(makeflexiblefunctor (@__MODULE__ , $ (esc (T)), $ (esc (pfield))))
58
- end
59
-
60
- macro flexiblefunctor (args... )
61
- flexiblefunctorm (args... )
62
- end
63
-
64
37
isleaf (x) = children (x) === ()
65
38
66
39
children (x) = functor (x)[1 ]
67
40
68
- function functor_tuple (f, x:: Tuple , dx:: Tuple )
69
- map (x, dx) do x, x̄
70
- _default_walk (f, x, x̄)
71
- end
72
- end
73
- functor_tuple (f, x, dx) = f (x, dx)
74
- functor_tuple (f, x, :: Nothing ) = x
75
-
76
- # @functor Chain
77
- # Chain -> func = (layers = (Dense,Dense),), gs -> (layers...)
78
- function _default_walk (f, x, dx)
79
- func, re = functor (x)
80
- map (func, dx) do x, x̄
81
- # functor_tuple(f, x, x̄)
82
- f (x, x̄)
83
- end |> re
84
- end
85
-
86
41
function _default_walk (f, x)
87
42
func, re = functor (x)
88
43
re (map (f, func))
89
44
end
90
- _default_walk (f, :: Nothing , :: Nothing ) = nothing
91
45
92
46
function fmap (f, x; exclude = isleaf, walk = _default_walk, cache = IdDict ())
93
47
haskey (cache, x) && return cache[x]
@@ -97,6 +51,10 @@ function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
97
51
return y
98
52
end
99
53
54
+ # ##
55
+ # ## Extras
56
+ # ##
57
+
100
58
fmapstructure (f, x; kwargs... ) = fmap (f, x; walk = (f, x) -> map (f, children (x)), kwargs... )
101
59
102
60
function fcollect (x; output = [], cache = Base. IdSet (), exclude = v -> false )
@@ -112,10 +70,59 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
112
70
return output
113
71
end
114
72
115
- # Allow gradients and other constructs that match the structure of the functor
116
- # to allow for `map` style computations and return a modified version of the struct.
117
- # This way we can use `fmap` to update the params with their gradients
73
+ # ##
74
+ # ## Vararg forms
75
+ # ##
76
+
118
77
function fmap (f, x, dx... ; cache = IdDict ())
119
78
haskey (cache, x) && return cache[x]
120
79
cache[x] = isleaf (x) ? f (x, dx... ) : _default_walk ((x... ) -> fmap (f, x... , cache = cache), x, dx... )
121
80
end
81
+
82
+ function functor_tuple (f, x:: Tuple , dx:: Tuple )
83
+ map (x, dx) do x, x̄
84
+ _default_walk (f, x, x̄)
85
+ end
86
+ end
87
+ functor_tuple (f, x, dx) = f (x, dx)
88
+ functor_tuple (f, x, :: Nothing ) = x
89
+
90
+ function _default_walk (f, x, dx)
91
+ func, re = functor (x)
92
+ map (func, dx) do x, x̄
93
+ # functor_tuple(f, x, x̄)
94
+ f (x, x̄)
95
+ end |> re
96
+ end
97
+ _default_walk (f, :: Nothing , :: Nothing ) = nothing
98
+
99
+ # ##
100
+ # ## FlexibleFunctors.jl
101
+ # ##
102
+
103
+ function makeflexiblefunctor (m:: Module , T, pfield)
104
+ pfield = QuoteNode (pfield)
105
+ @eval m begin
106
+ function $Functors. functor (:: Type{<:$T} , x)
107
+ pfields = getproperty (x, $ pfield)
108
+ function re (y)
109
+ all_args = map (fn -> getproperty (fn in pfields ? y : x, fn), fieldnames ($ T))
110
+ return $ T (all_args... )
111
+ end
112
+ func = NamedTuple {pfields} (map (p -> getproperty (x, p), pfields))
113
+ return func, re
114
+ end
115
+
116
+ end
117
+
118
+ end
119
+
120
+ function flexiblefunctorm (T, pfield = :params )
121
+ pfield isa Symbol || error (" @flexiblefunctor T param_field" )
122
+ pfield = QuoteNode (pfield)
123
+ :(makeflexiblefunctor (@__MODULE__ , $ (esc (T)), $ (esc (pfield))))
124
+ end
125
+
126
+ macro flexiblefunctor (args... )
127
+ flexiblefunctorm (args... )
128
+ end
0 commit comments