Skip to content

Commit 186009a

Browse files
committed
make swap_order in chain a positional arg
This, together with adding `@inline` makes constant-propagation possible. Also fix a bug from before.
1 parent e3ce538 commit 186009a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/differential_arithmetic.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ for T in (:Any,)
115115
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
116116
end
117117

118-
function chain(outer, inner; swap_order=false)
118+
@inline function chain(outer, inner, swap_order=false)
119119
if swap_order
120120
return Wirtinger(
121121
wirtinger_primal(inner) * wirtinger_primal(outer) +
@@ -132,10 +132,10 @@ function chain(outer, inner; swap_order=false)
132132
) |> refine_differential
133133
end
134134

135-
function chain(outer::ComplexGradient, inner; swap_order=false)
135+
@inline function chain(outer::ComplexGradient, inner, swap_order=false)
136136
if swap_order
137137
return ComplexGradient(
138-
wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)) *
138+
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) *
139139
outer.val
140140
)
141141
end
@@ -145,7 +145,7 @@ function chain(outer::ComplexGradient, inner; swap_order=false)
145145
)
146146
end
147147

148-
function chain(outer::ComplexGradient, inner::ComplexGradient; swap_order=false)
148+
@inline function chain(outer::ComplexGradient, inner::ComplexGradient, swap_order=false)
149149
if swap_order
150150
return ComplexGradient(conj(inner.val) * outer.val)
151151
end

0 commit comments

Comments
 (0)