Skip to content

Commit abfe271

Browse files
authored
Some Adjoint and Transpose stuff (#324)
* Extend Adjoint / Transpose * Allow Composite cotangents * Remove parent implementations for now * Remove TODO comment * Account for transposing adjoints * Test transpose * Loosen restriction * Bump patch
1 parent 22eddb4 commit abfe271

File tree

3 files changed

+72
-37
lines changed

3 files changed

+72
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.38"
3+
version = "0.7.39"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -90,64 +90,55 @@ end
9090
##### `Adjoint`
9191
#####
9292

93-
# ✖️✖️✖️TODO: Deal with complex-valued arrays as well
94-
function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real})
95-
function Adjoint_pullback(ȳ)
96-
return (NO_FIELDS, adjoint(ȳ))
97-
end
93+
function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number})
94+
Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
95+
Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ))
9896
return Adjoint(A), Adjoint_pullback
9997
end
10098

101-
function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real})
102-
function Adjoint_pullback(ȳ)
103-
return (NO_FIELDS, vec(adjoint(ȳ)))
104-
end
99+
function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number})
100+
Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
101+
Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ)))
105102
return Adjoint(A), Adjoint_pullback
106103
end
107104

108-
function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real})
109-
function adjoint_pullback(ȳ)
110-
return (NO_FIELDS, adjoint(ȳ))
111-
end
105+
function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number})
106+
adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
107+
adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ))
112108
return adjoint(A), adjoint_pullback
113109
end
114110

115-
function rrule(::typeof(adjoint), A::AbstractVector{<:Real})
116-
function adjoint_pullback(ȳ)
117-
return (NO_FIELDS, vec(adjoint(ȳ)))
118-
end
111+
function rrule(::typeof(adjoint), A::AbstractVector{<:Number})
112+
adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
113+
adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ)))
119114
return adjoint(A), adjoint_pullback
120115
end
121116

122117
#####
123118
##### `Transpose`
124119
#####
125120

126-
function rrule(::Type{<:Transpose}, A::AbstractMatrix)
127-
function Transpose_pullback(ȳ)
128-
return (NO_FIELDS, transpose(ȳ))
129-
end
121+
function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number})
122+
Transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
123+
Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ))
130124
return Transpose(A), Transpose_pullback
131125
end
132126

133-
function rrule(::Type{<:Transpose}, A::AbstractVector)
134-
function Transpose_pullback(ȳ)
135-
return (NO_FIELDS, vec(transpose(ȳ)))
136-
end
127+
function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number})
128+
Transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
129+
Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ)))
137130
return Transpose(A), Transpose_pullback
138131
end
139132

140-
function rrule(::typeof(transpose), A::AbstractMatrix)
141-
function transpose_pullback(ȳ)
142-
return (NO_FIELDS, transpose(ȳ))
143-
end
133+
function rrule(::typeof(transpose), A::AbstractMatrix{<:Number})
134+
transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent)
135+
transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ))
144136
return transpose(A), transpose_pullback
145137
end
146138

147-
function rrule(::typeof(transpose), A::AbstractVector)
148-
function transpose_pullback(ȳ)
149-
return (NO_FIELDS, vec(transpose(ȳ)))
150-
end
139+
function rrule(::typeof(transpose), A::AbstractVector{<:Number})
140+
transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent))
141+
transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ)))
151142
return transpose(A), transpose_pullback
152143
end
153144

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,55 @@
104104
end
105105
end
106106
end
107-
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
107+
@testset "$f, $T" for
108+
f in (Adjoint, adjoint, Transpose, transpose),
109+
T in (Float64, ComplexF64)
110+
108111
n = 5
109112
m = 3
110-
rrule_test(f, randn(m, n), (randn(n, m), randn(n, m)))
111-
rrule_test(f, randn(1, n), (randn(n), randn(n)))
113+
@testset "$f(::Matrix{$T})" begin
114+
A = randn(T, n, m)
115+
= randn(T, n, m)
116+
Y = f(A)
117+
Ȳ_mat = randn(T, m, n)
118+
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))
119+
120+
rrule_test(f, Ȳ_mat, (A, Ā))
121+
122+
_, pb = rrule(f, A)
123+
@test pb(Ȳ_mat) == pb(Ȳ_composite)
124+
end
125+
126+
@testset "$f(::Vector{$T})" begin
127+
a = randn(T, n)
128+
= randn(T, n)
129+
y = f(a)
130+
ȳ_mat = randn(T, 1, n)
131+
ȳ_composite = Composite{typeof(y)}(parent=collect(f(ȳ_mat)))
132+
133+
rrule_test(f, ȳ_mat, (a, ā))
134+
135+
_, pb = rrule(f, a)
136+
@test pb(ȳ_mat) == pb(ȳ_composite)
137+
end
138+
139+
@testset "$f(::Adjoint{$T, Vector{$T})" begin
140+
a = randn(T, n)'
141+
= randn(T, n)'
142+
y = f(a)
143+
= randn(T, n)
144+
145+
rrule_test(f, ȳ, (a, ā))
146+
end
147+
148+
@testset "$f(::Transpose{$T, Vector{$T})" begin
149+
a = transpose(randn(T, n))
150+
= transpose(randn(T, n))
151+
y = f(a)
152+
= randn(T, n)
153+
154+
rrule_test(f, ȳ, (a, ā))
155+
end
112156
end
113157
@testset "$T" for T in (UpperTriangular, LowerTriangular)
114158
n = 5

0 commit comments

Comments
 (0)