Skip to content

Commit cb816d0

Browse files
authored
Rule for permutedims (#559)
* add rule for permutedims * inference on 1.0 * move invperm * test Diagonal * typo
1 parent 9b69475 commit cb816d0

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/rulesets/Base/array.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,32 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
6565
return reshape(A, dims...), reshape_pullback
6666
end
6767

68+
#####
69+
##### `permutedims`
70+
#####
71+
72+
function rrule(::typeof(permutedims), x::AbstractVector)
73+
project = ProjectTo(x)
74+
permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
75+
return permutedims(x), permutedims_pullback_1
76+
end
77+
78+
function rrule(::typeof(permutedims), x::AbstractArray, perm)
79+
pr = ProjectTo(x) # projection restores e.g. transpose([1,2,3])
80+
permutedims_back_2(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
81+
return permutedims(x, perm), permutedims_back_2
82+
end
83+
84+
function rrule(::typeof(PermutedDimsArray), x::AbstractArray, perm)
85+
pr = ProjectTo(x)
86+
permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
87+
return PermutedDimsArray(x, perm), permutedims_back_3
88+
end
89+
6890
#####
6991
##### `repeat`
7092
#####
93+
7194
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
7295

7396
project_Xs = ProjectTo(xs)

test/rulesets/Base/array.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,22 @@ end
4242
test_rrule(reshape, rand(4, 5), 2, :)
4343
end
4444

45+
@testset "permutedims + PermutedDimsArray" begin
46+
test_rrule(permutedims, rand(5))
47+
48+
test_rrule(permutedims, rand(3, 4), (2, 1))
49+
test_rrule(permutedims, Diagonal(rand(5)), (2, 1))
50+
# Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all
51+
52+
@test invperm((3, 1, 2)) != (3, 1, 2)
53+
test_rrule(permutedims, rand(3, 4, 5), (3, 1, 2); check_inferred=VERSION>=v"1.1")
54+
55+
@test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2))
56+
x = rand(2, 3, 4)
57+
dy = rand(4, 2, 3)
58+
@test rrule(permutedims, x, (3, 1, 2))[2](dy)[2] == rrule(PermutedDimsArray, x, (3, 1, 2))[2](dy)[2]
59+
end
60+
4561
@testset "repeat" begin
4662

4763
test_rrule(repeat, rand(4, ))

0 commit comments

Comments
 (0)