Skip to content

Commit d8baa8f

Browse files
authored
Port more array-related derivatives from Nabla (#45)
* reshape * hcat * vcat * fill
1 parent 47d0202 commit d8baa8f

File tree

4 files changed

+116
-0
lines changed

4 files changed

+116
-0
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export AbstractRule, Rule, frule, rrule
1111
include("differentials.jl")
1212
include("rules.jl")
1313
include("rules/base.jl")
14+
include("rules/array.jl")
1415
include("rules/broadcast.jl")
1516
include("rules/linalg/dense.jl")
1617
include("rules/linalg/diagonal.jl")

src/rules/array.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#####
2+
##### `reshape`
3+
#####
4+
5+
function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
6+
return reshape(A, dims), (Rule(Ȳ->reshape(Ȳ, dims)), DNERule())
7+
end
8+
9+
function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
10+
Y, (rule, _) = rrule(reshape, A, dims)
11+
return Y, (rule, fill(DNERule(), length(dims))...)
12+
end
13+
14+
#####
15+
##### `hcat` (🐈)
16+
#####
17+
18+
function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
19+
Y = hcat(A, Bs...)
20+
Xs = (A, Bs...)
21+
rules = ntuple(length(Bs) + 1) do i
22+
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
23+
u = l + size(Xs[i], 2)
24+
dim = u > l + 1 ? (l+1:u) : u
25+
# NOTE: The copy here is defensive, since `selectdim` returns a view which we can
26+
# materialize with `copy`
27+
Rule(Ȳ->copy(selectdim(Ȳ, 2, dim)))
28+
end
29+
return Y, rules
30+
end
31+
32+
#####
33+
##### `vcat`
34+
#####
35+
36+
function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
37+
Y = vcat(A, Bs...)
38+
n = size(A, 1)
39+
∂A = Rule(Ȳ->copy(selectdim(Ȳ, 1, 1:n)))
40+
∂Bs = ntuple(length(Bs)) do i
41+
l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0)
42+
u = l + size(Bs[i], 1)
43+
Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u)))
44+
end
45+
return Y, (∂A, ∂Bs...)
46+
end
47+
48+
#####
49+
##### `fill`
50+
#####
51+
52+
function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
53+
return fill(value, dims), (Rule(sum), DNERule())
54+
end
55+
56+
function rrule(::typeof(fill), value::Any, dims::Int...)
57+
return fill(value, dims), (Rule(sum), ntuple(_->DNERule(), length(dims))...)
58+
end

test/rules/array.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
@testset "reshape" begin
2+
rng = MersenneTwister(1)
3+
A = randn(rng, 4, 5)
4+
B, (dA, dd) = rrule(reshape, A, (5, 4))
5+
@test B == reshape(A, (5, 4))
6+
@test dd isa ChainRules.DNERule
7+
= randn(rng, 4, 5)
8+
= dA(Ȳ)
9+
@test== reshape(Ȳ, (5, 4))
10+
11+
B, (dA, dd1, dd2) = rrule(reshape, A, 5, 4)
12+
@test B == reshape(A, 5, 4)
13+
@test dd1 isa ChainRules.DNERule
14+
@test dd2 isa ChainRules.DNERule
15+
= randn(rng, 4, 5)
16+
= dA(Ȳ)
17+
@test== reshape(Ȳ, 5, 4)
18+
end
19+
20+
@testset "hcat" begin
21+
rng = MersenneTwister(2)
22+
A = randn(rng, 3, 2)
23+
B = randn(rng, 3)
24+
C = randn(rng, 3, 3)
25+
H, (dA, dB, dC) = rrule(hcat, A, B, C)
26+
@test H == hcat(A, B, C)
27+
= randn(rng, 3, 6)
28+
@test dA(H̄) view(H̄, :, 1:2)
29+
@test dB(H̄) view(H̄, :, 3)
30+
@test dC(H̄) view(H̄, :, 4:6)
31+
end
32+
33+
@testset "vcat" begin
34+
rng = MersenneTwister(3)
35+
A = randn(rng, 2, 4)
36+
B = randn(rng, 1, 4)
37+
C = randn(rng, 3, 4)
38+
V, (dA, dB, dC) = rrule(vcat, A, B, C)
39+
@test V == vcat(A, B, C)
40+
= randn(rng, 6, 4)
41+
@test dA(V̄) view(V̄, 1:2, :)
42+
@test dB(V̄) view(V̄, 3:3, :)
43+
@test dC(V̄) view(V̄, 4:6, :)
44+
end
45+
46+
@testset "fill" begin
47+
y, (dv, dd) = rrule(fill, 44, 4)
48+
@test y == [44, 44, 44, 44]
49+
@test dd isa ChainRules.DNERule
50+
@test dv(ones(Int, 4)) == 4
51+
52+
y, (dv, dd) = rrule(fill, 2.0, (3, 3, 3))
53+
@test y == fill(2.0, (3, 3, 3))
54+
@test dd isa ChainRules.DNERule
55+
@test dv(ones(3, 3, 3)) 27.0
56+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include("test_util.jl")
1414
include("rules.jl")
1515
@testset "rules" begin
1616
include(joinpath("rules", "base.jl"))
17+
include(joinpath("rules", "array.jl"))
1718
@testset "linalg" begin
1819
include(joinpath("rules", "linalg", "dense.jl"))
1920
include(joinpath("rules", "linalg", "diagonal.jl"))

0 commit comments

Comments
 (0)