Skip to content

Commit 6ce400c

Browse files
committed
move at_thunk-magic into separate macro
1 parent 06ccb14 commit 6ce400c

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/differentials.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,16 +296,9 @@ end
296296
"""
297297
@thunk body
298298
299-
Returns `Thunk(() -> body)`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
300-
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
299+
Returns `Thunk(() -> body)`
301300
"""
302301
macro thunk(body)
303-
if body isa Expr && body.head == :call
304-
fname = body.args[1]
305-
if fname in (:Wirtinger, :ComplexGradient)
306-
return :($fname($((:(@thunk $i) for i in body.args[2:end])...)))
307-
end
308-
end
309302
return :(Thunk(() -> $(esc(body))))
310303
end
311304

src/rule_definition_tools.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,16 +224,32 @@ end
224224
"""
225225
function frule_propagation_expr(𝒟, Δs, ∂s)
226226
∂s = map(esc, ∂s)
227-
∂_mul_Δs = [:(chain(@thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
227+
∂_mul_Δs = [:(chain(@_thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
228228
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
229229
end
230230

231231
function rrule_propagation_expr(𝒟, Δs, ∂s)
232232
∂s = map(esc, ∂s)
233-
∂_mul_Δs = [:(chain($(Δs[i]), @thunk($(∂s[i])))) for i in 1:length(∂s)]
233+
∂_mul_Δs = [:(chain($(Δs[i]), @_thunk($(∂s[i])))) for i in 1:length(∂s)]
234234
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
235235
end
236236

237+
"""
238+
@_thunk body
239+
240+
Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
241+
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242+
"""
243+
macro _thunk(body)
244+
if body isa Expr && body.head == :call
245+
fname = body.args[1]
246+
if fname in (:Wirtinger, :ComplexGradient)
247+
return :($fname($((:(@thunk $(esc(i))) for i in body.args[2:end])...)))
248+
end
249+
end
250+
return :(@thunk $(esc(body)))
251+
end
252+
237253
"""
238254
propagator_name(f, propname)
239255

0 commit comments

Comments
 (0)