diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 56e02b02a..d63c707e7 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -294,8 +294,8 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. - init_expr = :(*($(_∂s[1]), $(Δs[1]))) - _∂s_Δs_tail = Iterators.drop(zip(_∂s, Δs), 1) + (∂s_1, Δs_1), _∂s_Δs_tail = Iterators.peel(zip(_∂s, Δs)) + init_expr = :($∂s_1 * $Δs_1) summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) :(muladd($∂s_i, $Δs_i, $ex)) end