Skip to content

Commit d2526d2

Browse files
author
Miha Zgubic
committed
corner case for repeat
1 parent e4029df commit d2526d2

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/rulesets/Base/array.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,19 @@ function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
166166
return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...)
167167
end
168168

169-
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
169+
function rrule(::typeof(repeat), xs::AbstractArray; inner=nothing, outer=nothing)
170170

171171
project_Xs = ProjectTo(xs)
172172
S = size(xs)
173+
inner_size = inner === nothing ? ntuple(Returns(1), ndims(xs)) : inner
173174
function repeat_pullback(ȳ)
174175
dY = unthunk(ȳ)
175176
Δ′ = zero(xs)
176177
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
177178
for (dest_idx, val) in pairs(IndexCartesian(), dY)
178-
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
179+
# First, round dest_idx[dim] to nearest gridpoint defined by inner_dims[dim], then
179180
# wrap around based on original size S.
180-
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
181+
src_idx = [mod1(div(dest_idx[dim] - 1, inner_size[dim]) + 1, S[dim]) for dim in 1:length(S)]
181182
Δ′[src_idx...] += val
182183
end
183184
= project_Xs(Δ′)

test/rulesets/Base/array.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ end
128128
test_rrule(repeat, rand(4, 5))
129129
test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
130130
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
131+
test_rrule(repeat, rand(4, 5); fkwargs = (outer=2,))
131132

132133
test_rrule(repeat, rand(4, ), 2)
133134
test_rrule(repeat, rand(4, 5), 2)

0 commit comments

Comments
 (0)