Skip to content

Commit 4d21e1a

Browse files
committed
special case AbstractWirtinger in at_thunk
1 parent a83da2f commit 4d21e1a

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/differentials.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ struct Thunk{F} <: AbstractThunk
265265
end
266266

267267
macro thunk(body)
268+
if body isa Expr && body.head == :call
269+
fname = body.args[1]
270+
if fname in (:Wirtinger, :ComplexGradient)
271+
return :($fname($((:(@thunk $i) for i in body.args[2:end])...)))
272+
end
273+
end
268274
return :(Thunk(() -> $(esc(body))))
269275
end
270276

test/differentials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
@test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4
9090

9191
# For most differentials, in most domains, this does nothing
92-
for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
92+
for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0)
9393
for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2]))
9494
@test refine_differential(𝒟, der) === der
9595
end

0 commit comments

Comments
 (0)