From d94390ab8d9daa8a498151e0208fd0c7ac009614 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 19 Jul 2025 16:42:03 +0200 Subject: [PATCH 1/2] Simplify - with at-j --- src/reverse_mode.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 7b3f051..206946b 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -147,11 +147,11 @@ function _forward_eval( @inbounds ix1 = children_arr[child1] @inbounds ix2 = children_arr[child1+1] for j in _eachindex(f.sizes, k) - tmp_sub = _getindex(f.forward_storage, f.sizes, ix1, j) - tmp_sub -= _getindex(f.forward_storage, f.sizes, ix2, j) - _setindex!(f.partials_storage, one(T), f.sizes, ix1, j) - _setindex!(f.partials_storage, -one(T), f.sizes, ix2, j) - _setindex!(f.forward_storage, tmp_sub, f.sizes, k, j) + tmp_sub = @j f.forward_storage[ix1] + tmp_sub -= @j f.forward_storage[ix2] + @j f.partials_storage[ix1] = one(T) + @j f.partials_storage[ix2] = -one(T) + @j f.forward_storage[k] = tmp_sub end elseif node.index == 3 # :* tmp_prod = one(T) From 083848375d8648d15dcb4192f7a13a5e774a04fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Sat, 19 Jul 2025 16:45:12 +0200 Subject: [PATCH 2/2] Fix format --- src/sizes.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/sizes.jl b/src/sizes.jl index cecadbd..e35f3d6 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -61,10 +61,25 @@ macro j(expr) lhs, rhs = expr.args @assert Meta.isexpr(lhs, :ref) @assert length(expr.args) == 2 - return Expr(:call, :_setindex!, esc(lhs.args[1]), esc(rhs), esc(:(f.sizes)), esc(lhs.args[2]), esc(:j)) + return Expr( + :call, + :_setindex!, + esc(lhs.args[1]), + esc(rhs), + esc(:(f.sizes)), + esc(lhs.args[2]), + esc(:j), + ) elseif Meta.isexpr(expr, :ref) && length(expr.args) == 2 arr, idx = expr.args - return Expr(:call, :_getindex, esc(arr), esc(:(f.sizes)), esc(idx), esc(:j)) + return Expr( + :call, + :_getindex, + esc(arr), + esc(:(f.sizes)), + esc(idx), + esc(:j), + ) else error("Unsupported expression `$expr`") end