@@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
156
156
Δs = [Symbol (string (:Δ , i)) for i in 1 : n_inputs]
157
157
pushforward_returns = map (1 : n_outputs) do output_i
158
158
∂s = partials[output_i]. args
159
- propagation_expr (𝒟, Δs, ∂s)
159
+ frule_propagation_expr (𝒟, Δs, ∂s)
160
160
end
161
161
if n_outputs > 1
162
162
# For forward-mode we only return a tuple if output actually a tuple.
@@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
193
193
# 1 partial derivative per input
194
194
pullback_returns = map (1 : n_inputs) do input_i
195
195
∂s = [partial. args[input_i] for partial in partials]
196
- propagation_expr (𝒟, Δs, ∂s)
196
+ rrule_propagation_expr (𝒟, Δs, ∂s)
197
197
end
198
198
199
199
pullback = quote
@@ -222,56 +222,16 @@ end
222
222
if it is taken at `1+1im` it returns `Complex{Int}`.
223
223
At present it is ignored for non-Wirtinger derivatives.
224
224
"""
225
- function propagation_expr (𝒟, Δs, ∂s)
226
- wirtinger_indices = findall (∂s) do ex
227
- Meta. isexpr (ex, :call ) && ex. args[1 ] === :Wirtinger
228
- end
225
+ function frule_propagation_expr (𝒟, Δs, ∂s)
229
226
∂s = map (esc, ∂s)
230
- if isempty (wirtinger_indices)
231
- return standard_propagation_expr (Δs, ∂s)
232
- else
233
- return wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
234
- end
235
- end
236
-
237
- function standard_propagation_expr (Δs, ∂s)
238
- # This is basically Δs ⋅ ∂s
239
-
240
- # Notice: the thunking of `∂s[i] (potentially) saves us some computation
241
- # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
242
- # as the pullback is evaluated
243
- ∂_mul_Δs = [:(@thunk ($ (∂s[i])) * $ (Δs[i])) for i in 1 : length (∂s)]
244
- return :(+ ($ (∂_mul_Δs... )))
227
+ ∂_mul_Δs = [:(chain (@thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
228
+ return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
245
229
end
246
230
247
- function wirtinger_propagation_expr (𝒟, wirtinger_indices, Δs, ∂s)
248
- ∂_mul_Δs_primal = Any[]
249
- ∂_mul_Δs_conjugate = Any[]
250
- ∂_wirtinger_defs = Any[]
251
- for i in 1 : length (∂s)
252
- if i in wirtinger_indices
253
- Δi = Δs[i]
254
- ∂i = Symbol (string (:∂ , i))
255
- push! (∂_wirtinger_defs, :($ ∂i = $ (∂s[i])))
256
- ∂f∂i_mul_Δ = :(wirtinger_primal ($ ∂i) * wirtinger_primal ($ Δi))
257
- ∂f∂ī_mul_Δ̄ = :(conj (wirtinger_conjugate ($ ∂i)) * wirtinger_conjugate ($ Δi))
258
- ∂f̄∂i_mul_Δ = :(wirtinger_conjugate ($ ∂i) * wirtinger_primal ($ Δi))
259
- ∂f̄∂ī_mul_Δ̄ = :(conj (wirtinger_primal ($ ∂i)) * wirtinger_conjugate ($ Δi))
260
- push! (∂_mul_Δs_primal, :($ ∂f∂i_mul_Δ + $ ∂f∂ī_mul_Δ̄))
261
- push! (∂_mul_Δs_conjugate, :($ ∂f̄∂i_mul_Δ + $ ∂f̄∂ī_mul_Δ̄))
262
- else
263
- ∂_mul_Δ = :(@thunk ($ (∂s[i])) * $ (Δs[i]))
264
- push! (∂_mul_Δs_primal, ∂_mul_Δ)
265
- push! (∂_mul_Δs_conjugate, ∂_mul_Δ)
266
- end
267
- end
268
- primal_sum = :(+ ($ (∂_mul_Δs_primal... )))
269
- conjugate_sum = :(+ ($ (∂_mul_Δs_conjugate... )))
270
- return quote # This will be a block, so will have value equal to last statement
271
- $ (∂_wirtinger_defs... )
272
- w = Wirtinger ($ primal_sum, $ conjugate_sum)
273
- refine_differential ($ 𝒟, w)
274
- end
231
+ function rrule_propagation_expr (𝒟, Δs, ∂s)
232
+ ∂s = map (esc, ∂s)
233
+ ∂_mul_Δs = [:(chain ($ (Δs[i]), @thunk ($ (∂s[i])))) for i in 1 : length (∂s)]
234
+ return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
275
235
end
276
236
277
237
"""
0 commit comments