Skip to content

Commit a0d86fe

Browse files
authored
Merge pull request #628 from JuliaDiff/mz/repeat
Fix corner case for repeat
2 parents e4029df + a71bb62 commit a0d86fe

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.35.1"
3+
version = "1.35.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,19 @@ function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
166166
return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...)
167167
end
168168

169-
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
169+
function rrule(::typeof(repeat), xs::AbstractArray; inner=nothing, outer=nothing)
170170

171171
project_Xs = ProjectTo(xs)
172172
S = size(xs)
173+
inner_size = inner === nothing ? ntuple(Returns(1), ndims(xs)) : inner
173174
function repeat_pullback(ȳ)
174175
dY = unthunk(ȳ)
175176
Δ′ = zero(xs)
176177
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
177178
for (dest_idx, val) in pairs(IndexCartesian(), dY)
178-
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
179+
# First, round dest_idx[dim] to nearest gridpoint defined by inner_dims[dim], then
179180
# wrap around based on original size S.
180-
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
181+
src_idx = [mod1(div(dest_idx[dim] - 1, inner_size[dim]) + 1, S[dim]) for dim in 1:length(S)]
181182
Δ′[src_idx...] += val
182183
end
183184
= project_Xs(Δ′)

test/rulesets/Base/array.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ end
128128
test_rrule(repeat, rand(4, 5))
129129
test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
130130
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
131+
test_rrule(repeat, rand(4, 5); fkwargs = (outer=2,))
131132

132133
test_rrule(repeat, rand(4, ), 2)
133134
test_rrule(repeat, rand(4, 5), 2)

0 commit comments

Comments
 (0)