Skip to content

Commit 3bc7e22

Browse files
committed
use _thunk as a function, not macro
1 parent b960a09 commit 3bc7e22

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/rule_definition_tools.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,26 +224,22 @@ end
224224
"""
225225
function frule_propagation_expr(𝒟, Δs, ∂s)
226226
∂s = map(esc, ∂s)
227-
∂_mul_Δs = [:(chain(@_thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
227+
∂_mul_Δs = [:(chain($(_thunk(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
228228
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
229229
end
230230

231231
function rrule_propagation_expr(𝒟, Δs, ∂s)
232232
∂s = map(esc, ∂s)
233-
∂_mul_Δs = [:(chain($(Δs[i]), @_thunk($(∂s[i])))) for i in 1:length(∂s)]
233+
∂_mul_Δs = [:(chain($(Δs[i]), $(_thunk(∂s[i])))) for i in 1:length(∂s)]
234234
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
235235
end
236236

237237
"""
238-
@_thunk body
238+
_thunk(body)
239239
240240
Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
241241
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242242
"""
243-
macro _thunk(body)
244-
return _thunk(body)
245-
end
246-
247243
function _thunk(body)
248244
if body isa Expr
249245
if body.head == :call
@@ -261,8 +257,8 @@ end
261257
thunk_assert_no_wirtinger(body) = quote
262258
Thunk(
263259
function()
264-
res = $(esc(body))
265-
res isa AbstractWirtinger && error("""
260+
res = $body
261+
res isa ChainRulesCore.AbstractWirtinger && error("""
266262
Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule.
267263
Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""")
268264
return res

0 commit comments

Comments
 (0)