Skip to content

Commit b960a09

Browse files
committed
make at_scalar_rule detect wrong Wirtinger rules
Tests will break for now, and that's good.
1 parent 6ce400c commit b960a09

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

src/rule_definition_tools.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,33 @@ Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) o
241241
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242242
"""
243243
macro _thunk(body)
244-
if body isa Expr && body.head == :call
245-
fname = body.args[1]
246-
if fname in (:Wirtinger, :ComplexGradient)
247-
return :($fname($((:(@thunk $(esc(i))) for i in body.args[2:end])...)))
244+
return _thunk(body)
245+
end
246+
247+
function _thunk(body)
248+
if body isa Expr
249+
if body.head == :call
250+
fname = body.args[1]
251+
if fname in (:Wirtinger, :ComplexGradient)
252+
return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...)))
253+
end
254+
elseif body.head == :escape
255+
return Expr(:escape, _thunk(body.args[1]))
248256
end
249257
end
250-
return :(@thunk $(esc(body)))
258+
return thunk_assert_no_wirtinger(body)
259+
end
260+
261+
thunk_assert_no_wirtinger(body) = quote
262+
Thunk(
263+
function()
264+
res = $(esc(body))
265+
res isa AbstractWirtinger && error("""
266+
Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule.
267+
Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""")
268+
return res
269+
end
270+
)
251271
end
252272

253273
"""

0 commit comments

Comments
 (0)