Skip to content

Commit 8774ada

Browse files
committed
implement some of @oxinabox's suggestions
1 parent 3bc7e22 commit 8774ada

File tree

3 files changed

+6
-17
lines changed

3 files changed

+6
-17
lines changed

src/differentials.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ wirtinger_conjugate(::Any) = Zero()
6666

6767
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type."))
6868

69-
Base.iterate(x::AbstractWirtinger) = (x, nothing)
70-
Base.iterate(::AbstractWirtinger, ::Any) = nothing
71-
7269
# `conj` is not defined for `AbstractWirtinger`.
7370
# Need this method to override the definition of `conj` for `AbstractDifferential`.
7471
Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x))
@@ -102,9 +99,6 @@ wirtinger_conjugate(x::Wirtinger) = x.conjugate
10299
Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal),
103100
broadcastable(w.conjugate))
104101

105-
Base.iterate(x::Wirtinger) = (x, nothing)
106-
Base.iterate(::Wirtinger, ::Any) = nothing
107-
108102
#####
109103
##### `ComplexGradient`
110104
#####

src/rule_definition_tools.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,15 +241,13 @@ Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) o
241241
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242242
"""
243243
function _thunk(body)
244-
if body isa Expr
245-
if body.head == :call
246-
fname = body.args[1]
247-
if fname in (:Wirtinger, :ComplexGradient)
248-
return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...)))
249-
end
250-
elseif body.head == :escape
251-
return Expr(:escape, _thunk(body.args[1]))
244+
if Meta.isexpr(body, :call)
245+
fname = body.args[1]
246+
if fname in (:Wirtinger, :ComplexGradient)
247+
return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...)))
252248
end
249+
elseif Meta.isexpr(body, :escape)
250+
return Expr(:escape, _thunk(body.args[1]))
253251
end
254252
return thunk_assert_no_wirtinger(body)
255253
end

test/differentials.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
# TODO: other + methods stack overflow
1313
@test_throws ErrorException w*w
1414
@test_throws ArgumentError extern(w)
15-
for x in w
16-
@test x === w
17-
end
1815
@test broadcastable(w) == w
1916
@test_throws MethodError conj(w)
2017
end

0 commit comments

Comments
 (0)