Skip to content

Commit 97e3669

Browse files
authored
Simplify - with at-j (#9)
* Simplify - with at-j * Fix format
1 parent 525dd8f commit 97e3669

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

src/reverse_mode.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ function _forward_eval(
147147
@inbounds ix1 = children_arr[child1]
148148
@inbounds ix2 = children_arr[child1+1]
149149
for j in _eachindex(f.sizes, k)
150-
tmp_sub = _getindex(f.forward_storage, f.sizes, ix1, j)
151-
tmp_sub -= _getindex(f.forward_storage, f.sizes, ix2, j)
152-
_setindex!(f.partials_storage, one(T), f.sizes, ix1, j)
153-
_setindex!(f.partials_storage, -one(T), f.sizes, ix2, j)
154-
_setindex!(f.forward_storage, tmp_sub, f.sizes, k, j)
150+
tmp_sub = @j f.forward_storage[ix1]
151+
tmp_sub -= @j f.forward_storage[ix2]
152+
@j f.partials_storage[ix1] = one(T)
153+
@j f.partials_storage[ix2] = -one(T)
154+
@j f.forward_storage[k] = tmp_sub
155155
end
156156
elseif node.index == 3 # :*
157157
tmp_prod = one(T)

src/sizes.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,25 @@ macro j(expr)
6161
lhs, rhs = expr.args
6262
@assert Meta.isexpr(lhs, :ref)
6363
@assert length(expr.args) == 2
64-
return Expr(:call, :_setindex!, esc(lhs.args[1]), esc(rhs), esc(:(f.sizes)), esc(lhs.args[2]), esc(:j))
64+
return Expr(
65+
:call,
66+
:_setindex!,
67+
esc(lhs.args[1]),
68+
esc(rhs),
69+
esc(:(f.sizes)),
70+
esc(lhs.args[2]),
71+
esc(:j),
72+
)
6573
elseif Meta.isexpr(expr, :ref) && length(expr.args) == 2
6674
arr, idx = expr.args
67-
return Expr(:call, :_getindex, esc(arr), esc(:(f.sizes)), esc(idx), esc(:j))
75+
return Expr(
76+
:call,
77+
:_getindex,
78+
esc(arr),
79+
esc(:(f.sizes)),
80+
esc(idx),
81+
esc(:j),
82+
)
6883
else
6984
error("Unsupported expression `$expr`")
7085
end

0 commit comments

Comments
 (0)