Skip to content

Commit 61a60ec

Browse files
committed
use the new chain function in @scalar_rule
1 parent 2ac8801 commit 61a60ec

File tree

1 file changed

+9
-49
lines changed

1 file changed

+9
-49
lines changed

src/rule_definition_tools.jl

Lines changed: 9 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
156156
Δs = [Symbol(string(, i)) for i in 1:n_inputs]
157157
pushforward_returns = map(1:n_outputs) do output_i
158158
∂s = partials[output_i].args
159-
propagation_expr(𝒟, Δs, ∂s)
159+
frule_propagation_expr(𝒟, Δs, ∂s)
160160
end
161161
if n_outputs > 1
162162
# For forward-mode we only return a tuple if output actually a tuple.
@@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
193193
# 1 partial derivative per input
194194
pullback_returns = map(1:n_inputs) do input_i
195195
∂s = [partial.args[input_i] for partial in partials]
196-
propagation_expr(𝒟, Δs, ∂s)
196+
rrule_propagation_expr(𝒟, Δs, ∂s)
197197
end
198198

199199
pullback = quote
@@ -222,56 +222,16 @@ end
222222
if it is taken at `1+1im` it returns `Complex{Int}`.
223223
At present it is ignored for non-Wirtinger derivatives.
224224
"""
225-
function propagation_expr(𝒟, Δs, ∂s)
226-
wirtinger_indices = findall(∂s) do ex
227-
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
228-
end
225+
function frule_propagation_expr(𝒟, Δs, ∂s)
229226
∂s = map(esc, ∂s)
230-
if isempty(wirtinger_indices)
231-
return standard_propagation_expr(Δs, ∂s)
232-
else
233-
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
234-
end
235-
end
236-
237-
function standard_propagation_expr(Δs, ∂s)
238-
# This is basically Δs ⋅ ∂s
239-
240-
# Notice: the thunking of `∂s[i] (potentially) saves us some computation
241-
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
242-
# as the pullback is evaluated
243-
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
244-
return :(+($(∂_mul_Δs...)))
227+
∂_mul_Δs = [:(chain(@thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
228+
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
245229
end
246230

247-
function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
248-
∂_mul_Δs_primal = Any[]
249-
∂_mul_Δs_conjugate = Any[]
250-
∂_wirtinger_defs = Any[]
251-
for i in 1:length(∂s)
252-
if i in wirtinger_indices
253-
Δi = Δs[i]
254-
∂i = Symbol(string(:∂, i))
255-
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
256-
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
257-
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
258-
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
259-
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
260-
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
261-
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
262-
else
263-
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
264-
push!(∂_mul_Δs_primal, ∂_mul_Δ)
265-
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
266-
end
267-
end
268-
primal_sum = :(+($(∂_mul_Δs_primal...)))
269-
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
270-
return quote # This will be a block, so will have value equal to last statement
271-
$(∂_wirtinger_defs...)
272-
w = Wirtinger($primal_sum, $conjugate_sum)
273-
refine_differential($𝒟, w)
274-
end
231+
function rrule_propagation_expr(𝒟, Δs, ∂s)
232+
∂s = map(esc, ∂s)
233+
∂_mul_Δs = [:(chain($(Δs[i]), @thunk($(∂s[i])))) for i in 1:length(∂s)]
234+
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
275235
end
276236

277237
"""

0 commit comments

Comments
 (0)