File tree Expand file tree Collapse file tree 2 files changed +19
-10
lines changed Expand file tree Collapse file tree 2 files changed +19
-10
lines changed Original file line number Diff line number Diff line change 296
296
"""
297
297
@thunk body
298
298
299
- Returns `Thunk(() -> body)`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
300
- In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
299
+ Returns `Thunk(() -> body)`
301
300
"""
302
301
macro thunk (body)
303
- if body isa Expr && body. head == :call
304
- fname = body. args[1 ]
305
- if fname in (:Wirtinger , :ComplexGradient )
306
- return :($ fname ($ ((:(@thunk $ i) for i in body. args[2 : end ]). .. )))
307
- end
308
- end
309
302
return :(Thunk (() -> $ (esc (body))))
310
303
end
311
304
Original file line number Diff line number Diff line change @@ -224,16 +224,32 @@ end
224
224
"""
225
225
function frule_propagation_expr (𝒟, Δs, ∂s)
226
226
∂s = map (esc, ∂s)
227
- ∂_mul_Δs = [:(chain (@thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
227
+ ∂_mul_Δs = [:(chain (@_thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
228
228
return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
229
229
end
230
230
231
231
function rrule_propagation_expr (𝒟, Δs, ∂s)
232
232
∂s = map (esc, ∂s)
233
- ∂_mul_Δs = [:(chain ($ (Δs[i]), @thunk ($ (∂s[i])))) for i in 1 : length (∂s)]
233
+ ∂_mul_Δs = [:(chain ($ (Δs[i]), @_thunk ($ (∂s[i])))) for i in 1 : length (∂s)]
234
234
return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
235
235
end
236
236
237
+ """
238
+ @_thunk body
239
+
240
+ Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
241
+ In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242
+ """
243
+ macro _thunk (body)
244
+ if body isa Expr && body. head == :call
245
+ fname = body. args[1 ]
246
+ if fname in (:Wirtinger , :ComplexGradient )
247
+ return :($ fname ($ ((:(@thunk $ (esc (i))) for i in body. args[2 : end ]). .. )))
248
+ end
249
+ end
250
+ return :(@thunk $ (esc (body)))
251
+ end
252
+
237
253
"""
238
254
propagator_name(f, propname)
239
255
You can’t perform that action at this time.
0 commit comments