@@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...)
88
88
)
89
89
f = call. args[1 ]
90
90
91
- frule_expr = scalar_frule_expr (__source__, f, call, setup_stmts, inputs, partials)
92
- rrule_expr = scalar_rrule_expr (__source__, f, call, setup_stmts, inputs, partials)
91
+ # Generate variables to store derivatives named dfi/dxj
92
+ derivatives = map (keys (partials)) do i
93
+ syms = map (j -> Symbol (" ∂f" , i, " /∂x" , j), keys (inputs))
94
+ return Expr (:tuple , syms... )
95
+ end
96
+
97
+ derivative_expr = scalar_derivative_expr (__source__, f, setup_stmts, inputs, partials)
98
+ frule_expr = scalar_frule_expr (__source__, f, call, [], inputs, derivatives)
99
+ rrule_expr = scalar_rrule_expr (__source__, f, call, [], inputs, derivatives)
93
100
94
101
# Final return: building the expression to insert in the place of this macro
95
102
code = quote
@@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...)
99
106
))
100
107
end
101
108
109
+ $ (derivative_expr)
102
110
$ (frule_expr)
103
111
$ (rrule_expr)
104
112
end
@@ -135,16 +143,45 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
135
143
# For consistency in code that follows we make all partials tuple expressions
136
144
partials = map (partials) do partial
137
145
if Meta. isexpr (partial, :tuple )
138
- partial
146
+ Expr ( :tuple , map (esc, partial. args) ... )
139
147
else
140
148
length (inputs) == 1 || error (" Invalid use of `@scalar_rule`" )
141
- Expr (:tuple , partial)
149
+ Expr (:tuple , esc ( partial) )
142
150
end
143
151
end
144
152
145
153
return call, setup_stmts, inputs, partials
146
154
end
147
155
156
+ """
157
+ derivatives_given_output(Ω, f, xs...)
158
+
159
+ Compute the derivative of scalar function `f` at primal input point `xs...`,
160
+ given that it had primal output `Ω`.
161
+ Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`.
162
+ The derivative of the `i`-th component of `f` with respect to the `j`-th input can be
163
+ accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`.
164
+
165
+ !!! warning "Experimental"
166
+ This function is experimental and not part of the stable API.
167
+ At the moment, it can be considered an implementation detail of the macro
168
+ [`@scalar_rule`](@ref), in which it is used.
169
+ In the future, the exact semantics of this function will stabilize, and it
170
+ will be added to the stable API.
171
+ When that happens, this warning will be removed.
172
+
173
+ """
174
+ function derivatives_given_output end
175
+
176
+ function scalar_derivative_expr (__source__, f, setup_stmts, inputs, partials)
177
+ return @strip_linenos quote
178
+ function ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), :: Core.Typeof ($ f), $ (inputs... ))
179
+ $ (__source__)
180
+ $ (setup_stmts... )
181
+ return $ (Expr (:tuple , partials... ))
182
+ end
183
+ end
184
+ end
148
185
149
186
function scalar_frule_expr (__source__, f, call, setup_stmts, inputs, partials)
150
187
n_outputs = length (partials)
@@ -173,6 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
173
210
$ (__source__)
174
211
$ (esc (:Ω )) = $ call
175
212
$ (setup_stmts... )
213
+ $ (Expr (:tuple , partials... )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
176
214
return $ (esc (:Ω )), $ pushforward_returns
177
215
end
178
216
end
@@ -210,6 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
210
248
$ (__source__)
211
249
$ (esc (:Ω )) = $ call
212
250
$ (setup_stmts... )
251
+ $ (Expr (:tuple , partials... )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
213
252
return $ (esc (:Ω )), $ pullback
214
253
end
215
254
end
@@ -240,9 +279,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
240
279
# This is basically Δs ⋅ ∂s
241
280
_∂s = map (∂s) do ∂s_i
242
281
if _conj
243
- :(conj ($ ( esc ( ∂s_i)) ))
282
+ :(conj ($ ∂s_i))
244
283
else
245
- esc ( ∂s_i)
284
+ ∂s_i
246
285
end
247
286
end
248
287
0 commit comments