File tree Expand file tree Collapse file tree 1 file changed +5
-9
lines changed Expand file tree Collapse file tree 1 file changed +5
-9
lines changed Original file line number Diff line number Diff line change @@ -224,26 +224,22 @@ end
224
224
"""
225
225
function frule_propagation_expr (𝒟, Δs, ∂s)
226
226
∂s = map (esc, ∂s)
227
- ∂_mul_Δs = [:(chain (@_thunk ( $ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
227
+ ∂_mul_Δs = [:(chain ($ ( _thunk (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
228
228
return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
229
229
end
230
230
231
231
function rrule_propagation_expr (𝒟, Δs, ∂s)
232
232
∂s = map (esc, ∂s)
233
- ∂_mul_Δs = [:(chain ($ (Δs[i]), @_thunk ( $ (∂s[i])))) for i in 1 : length (∂s)]
233
+ ∂_mul_Δs = [:(chain ($ (Δs[i]), $ ( _thunk (∂s[i])))) for i in 1 : length (∂s)]
234
234
return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
235
235
end
236
236
237
237
"""
238
- @ _thunk body
238
+ _thunk( body)
239
239
240
240
Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
241
241
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242
242
"""
243
- macro _thunk (body)
244
- return _thunk (body)
245
- end
246
-
247
243
function _thunk (body)
248
244
if body isa Expr
249
245
if body. head == :call
261
257
thunk_assert_no_wirtinger (body) = quote
262
258
Thunk (
263
259
function ()
264
- res = $ ( esc ( body))
265
- res isa AbstractWirtinger && error ("""
260
+ res = $ body
261
+ res isa ChainRulesCore . AbstractWirtinger && error ("""
266
262
Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule.
267
263
Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""" )
268
264
return res
You can’t perform that action at this time.
0 commit comments