Skip to content

Commit 1fb2761

Browse files
AlexRobsonmzgubic
andauthored
Add missing generic repeat method (#466)
* Add generic repeat method. Add tests * Work through some 1.0 discrepancies * Add type inference hint * Remove type hinting and relax tests for 1.0 * Add in Project and @thunk * White space del * Wrap in thunk * Update Project.toml Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> * Remove thunks Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
1 parent 2e6491c commit 1fb2761

File tree

3 files changed

+34
-41
lines changed

3 files changed

+34
-41
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.8.1"
3+
version = "1.9"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -68,54 +68,39 @@ end
6868
#####
6969
##### `repeat`
7070
#####
71-
7271
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
7372

73+
project_Xs = ProjectTo(xs)
74+
S = size(xs)
7475
function repeat_pullback(ȳ)
7576
dY = unthunk(ȳ)
7677
Δ′ = zero(xs)
77-
S = size(xs)
78-
7978
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
8079
for (dest_idx, val) in pairs(IndexCartesian(), dY)
8180
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
8281
# wrap around based on original size S.
8382
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
8483
Δ′[src_idx...] += val
8584
end
86-
return (NoTangent(), Δ′)
85+
= project_Xs(Δ′)
86+
return (NoTangent(), x̄)
8787
end
8888

8989
return repeat(xs; inner = inner, outer = outer), repeat_pullback
9090
end
9191

92-
function rrule(::typeof(repeat), xs::AbstractVector, m::Integer)
92+
function rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)
9393

94-
d1 = size(xs, 1)
94+
project_Xs = ProjectTo(xs)
95+
S = size(xs)
9596
function repeat_pullback(ȳ)
96-
Δ′ = dropdims(sum(reshape(ȳ, d1, :); dims=2); dims=2)
97-
return (NoTangent(), Δ′, NoTangent())
98-
end
99-
100-
return repeat(xs, m), repeat_pullback
101-
end
102-
103-
function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer, n::Integer)
104-
d1, d2 = size(xs, 1), size(xs, 2)
105-
function repeat_pullback(ȳ)
106-
ȳ′ = reshape(ȳ, d1, m, d2, n)
107-
return NoTangent(), reshape(sum(ȳ′; dims=(2,4)), (d1, d2)), NoTangent(), NoTangent()
97+
dY = unthunk(ȳ)
98+
size2ndims = ntuple(d -> isodd(d) ? get(S, 1+d÷2, 1) : get(counts, d÷2, 1), 2*ndims(dY))
99+
reduced = sum(reshape(dY, size2ndims); dims = ntuple(d -> 2d, ndims(dY)))
100+
= project_Xs(reshape(reduced, S))
101+
return (NoTangent(), x̄, map(_->NoTangent(), counts)...)
108102
end
109-
110-
return repeat(xs, m, n), repeat_pullback
111-
end
112-
113-
function rrule(T::typeof(repeat), xs::AbstractVecOrMat, m::Integer)
114-
115-
# Workaround use of positional default (i.e. repeat(xs, m, n = 1)))
116-
y, full_pb = rrule(T, xs, m, 1)
117-
repeat_pullback(ȳ) = full_pb(ȳ)[1:3]
118-
return y, repeat_pullback
103+
return repeat(xs, counts...), repeat_pullback
119104
end
120105

121106
#####

test/rulesets/Base/array.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,35 @@ end
4343
end
4444

4545
@testset "repeat" begin
46+
4647
test_rrule(repeat, rand(4, ))
47-
test_rrule(repeat, rand(4, ), 2)
4848
test_rrule(repeat, rand(4, 5))
4949
test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
5050
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
5151

52+
test_rrule(repeat, rand(4, ), 2; check_inferred=VERSION>=v"1.6")
53+
test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.6")
54+
test_rrule(repeat, rand(4, 5), 2, 3; check_inferred=VERSION>=v"1.6")
55+
test_rrule(repeat, rand(1,2,3), 2,3,4; check_inferred=VERSION>v"1.6")
56+
test_rrule(repeat, rand(0,2,3), 2,0,4; check_inferred=VERSION>v"1.6")
57+
test_rrule(repeat, rand(1,1,1,1), 2,3,4,5; check_inferred=VERSION>v"1.6")
58+
59+
5260
if VERSION>=v"1.6"
53-
# repeat([1 2; 3 4], inner=(2,4), outer=(1,1,1,3)) fails for v<1.6
61+
# These are cases where repeat itself fails in earlier versions
5462
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(2,4), outer=(1,1,1,3)))
55-
end
56-
test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.5")
57-
test_rrule(repeat, rand(4, 5), 2, 3)
63+
test_rrule(repeat, rand(1,2,3), 2,3)
64+
test_rrule(repeat, rand(1,2,3), 2,3,4,2)
65+
test_rrule(repeat, fill(1.0), 2)
66+
test_rrule(repeat, fill(1.0), 2, 3)
5867

59-
# zero-arrays: broken
60-
@test_broken rrule(repeat, fill(1.0), 2) !== nothing
61-
@test_broken rrule(repeat, fill(1.0), 2, 3) !== nothing
68+
# These fail for other v1.0 related issues (add!!)
69+
# v"1.0": fill(1.0) + fill(1.0) != fill(2.0)
70+
# v"1.6: fill(1.0) + fill(1.0) == fill(2.0) # Expected
71+
test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
72+
test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,))
6273

63-
# These dispatch but probably needs
64-
# https://github.com/JuliaDiff/FiniteDifferences.jl/issues/179
65-
# test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
66-
# test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,))
74+
end
6775

6876
@test rrule(repeat, [1,2,3], 4)[2](ones(12))[2] == [4,4,4]
6977
@test rrule(repeat, [1,2,3], outer=4)[2](ones(12))[2] == [4,4,4]

0 commit comments

Comments
 (0)