@@ -14,19 +14,22 @@ methods for `frule` and `rrule`:
14
14
function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
15
15
Ω = f(x₁, x₂, ...)
16
16
\$ (statement₁, statement₂, ...)
17
- return Ω, (ZERO_RULE,
18
- Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
19
- Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
20
- ...)
17
+ return Ω, (_, Δx₁, Δx₂, ...) -> (
18
+ (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
19
+ (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
20
+ ...
21
+ )
21
22
end
22
23
23
24
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
24
25
Ω = f(x₁, x₂, ...)
25
26
\$ (statement₁, statement₂, ...)
26
- return Ω, (NO_FIELDS_RULE,
27
- Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
28
- Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
29
- ...)
27
+ return Ω, (ΔΩ₁, ΔΩ₂, ...) -> (
28
+ NO_FIELDS,
29
+ ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
30
+ ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
31
+ ...
32
+ )
30
33
end
31
34
32
35
If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are
@@ -36,10 +39,10 @@ Constraints may also be explicitly be provided to override the `Number` constrai
36
39
e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
37
40
`Number`.
38
41
39
- At present this does not support defining rules for closures/functors.
40
- This the first returned rule, representing the derivative with respect to the
41
- function itself, is always the `NO_FIELDS_RULE` (reverse-mode),
42
- or `ZERO_RULE` (forward -mode) .
42
+ At present this does not support defining for closures/functors.
43
+ Thus in reverse-mode, the first returned partial,
44
+ representing the derivative with respect to the function itself, is always `NO_FIELDS`.
45
+ And in forwards -mode, the first input to the returned propergator is always ignored .
43
46
44
47
The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
45
48
allows the primal result to be conveniently referenced (as `Ω`) within the
@@ -86,11 +89,54 @@ macro scalar_rule(call, maybe_setup, partials...)
86
89
call. args[i] = esc (arg)
87
90
end
88
91
end
89
- if all (Meta. isexpr (partial, :tuple ) for partial in partials)
92
+
93
+ partials = map (partials) do partial
94
+ if Meta. isexpr (partial, :tuple )
95
+ partial
96
+ else
97
+ @assert length (inputs) == 1
98
+ Expr (:tuple , partial)
99
+ end
100
+ end
101
+ @show partials
102
+
103
+ # ###########################################################
104
+ # Make pullback
105
+ # (TODO : move to own function)
106
+ # TODO : Wirtinger
107
+
108
+ Δs = [Symbol (string (:Δ , i)) for i in 1 : length (partials)]
109
+ pullback_returns = map (eachindex (inputs)) do input_i
110
+ ∂s = [partials. args[input_i] for partial in partials]
111
+ ∂s = map (esc, ∂s)
112
+
113
+ # Notice: the thunking of `∂s[i] (potentially) saves us some computation
114
+ # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
115
+ # as the pullback is evaluated
116
+ ∂_mul_Δs = [:(@thunk ($ (∂s[i])) * $ (Δs[i])) for i in 1 : length (∂s)]
117
+ :(+ ($ (∂_mul_Δs... )))
118
+ else
119
+
120
+ pullback = quote
121
+ function $ (Symbol (nameof (f), :_pullback ))($ (Δs... ))
122
+ return (ChainRulesCore. NO_FIELDS, $ (pullback_returns... ))
123
+ end
124
+ end
125
+
126
+ # #######################################
127
+ quote
128
+ function ChainRulesCore. rrule (:: typeof ($ f), $ (inputs... ))
129
+ $ (esc (:Ω )) = $ call
130
+ $ (setup_stmts... )
131
+ return $ (esc (:Ω )), $ esc (pullback)
132
+ end
133
+ end
134
+ end
135
+ #= =
136
+ if !all(Meta.isexpr(partial, :tuple) for partial in partials)
90
137
input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input
91
138
forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials]
92
- reverse_rules = Any[]
93
- for i in 1 : length (inputs)
139
+ reverse_rules = map(1:length(inputs) do i
94
140
reverse_partials = [partial.args[i] for partial in partials]
95
141
push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...))
96
142
end
@@ -103,7 +149,7 @@ macro scalar_rule(call, maybe_setup, partials...)
103
149
# First pseudo-partial is derivative WRT function itself. Since this macro does not
104
150
# support closures, it is just the empty NamedTuple
105
151
forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...)
106
- reverse_rules = Expr (:tuple , NO_FIELDS_RULE , reverse_rules... )
152
+ reverse_rules = Expr(:tuple, NO_FIELDS , reverse_rules...)
107
153
return quote
108
154
if fieldcount(typeof($f)) > 0
109
155
throw(ArgumentError(
@@ -123,7 +169,13 @@ macro scalar_rule(call, maybe_setup, partials...)
123
169
end
124
170
end
125
171
end
172
+ ==#
173
+
174
+ @macroexpand (@scalar_rule (one (x), Zero ()))
175
+
176
+
126
177
178
+ #= =
127
179
function rule_from_partials(input_arg, ∂s...)
128
180
wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
129
181
∂s = map(esc, ∂s)
@@ -161,3 +213,4 @@ function rule_from_partials(input_arg, ∂s...)
161
213
end
162
214
end
163
215
end
216
+ ==#
0 commit comments