Skip to content

Commit 525dd8f

Browse files
authored
Simplify with at-j macro (#8)
1 parent 103bd6e commit 525dd8f

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/reverse_mode.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,10 @@ function _forward_eval(
136136
tmp_sum = zero(T)
137137
for c_idx in children_indices
138138
ix = children_arr[c_idx]
139-
_setindex!(f.partials_storage, one(T), f.sizes, ix, j)
140-
f.partials_storage[ix] = one(T)
141-
tmp_sum += _getindex(f.forward_storage, f.sizes, ix, j)
139+
@j f.partials_storage[ix] = one(T)
140+
tmp_sum += @j f.forward_storage[ix]
142141
end
143-
_setindex!(f.forward_storage, tmp_sum, f.sizes, k, j)
142+
@j f.forward_storage[k] = tmp_sum
144143
end
145144
elseif node.index == 2 # :-
146145
@assert N == 2

src/sizes.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,27 @@ function _setindex!(x, value, sizes::Sizes, k::Int, j)
4949
return x[sizes.storage_offset[k]+j] = value
5050
end
5151

52+
"""
53+
@j(storage[node]) -> _getindex(storage, f.sizes, node, j)
54+
@j(storage[node] = value) -> _setindex!(storage, value, f.sizes, node, j)
55+
56+
This "at `j`" converts `getindex` and `setindex!` calls to access
57+
the sub-array in a vector corresponding to a node at its `j`th index.
58+
"""
59+
macro j(expr)
60+
if Meta.isexpr(expr, :(=)) && length(expr.args) == 2
61+
lhs, rhs = expr.args
62+
@assert Meta.isexpr(lhs, :ref)
63+
@assert length(expr.args) == 2
64+
return Expr(:call, :_setindex!, esc(lhs.args[1]), esc(rhs), esc(:(f.sizes)), esc(lhs.args[2]), esc(:j))
65+
elseif Meta.isexpr(expr, :ref) && length(expr.args) == 2
66+
arr, idx = expr.args
67+
return Expr(:call, :_getindex, esc(arr), esc(:(f.sizes)), esc(idx), esc(:j))
68+
else
69+
error("Unsupported expression `$expr`")
70+
end
71+
end
72+
5273
# /!\ Can only be called in decreasing `k` order
5374
function _add_size!(sizes::Sizes, k::Int, size::Tuple)
5475
sizes.ndims[k] = length(size)

0 commit comments

Comments
 (0)