Skip to content

Commit 314b08a

Browse files
authored
Add rrules for binary linear algebra operations (#29)
1 parent dc4adb0 commit 314b08a

File tree

3 files changed

+142
-1
lines changed

3 files changed

+142
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1010

1111
[compat]
1212
Cassette = "^0.2"
13-
FDM = "^0.5"
13+
FDM = "^0.6"
1414
julia = "^1.0"
1515

1616
[extras]

src/rules/linalg/dense.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
using LinearAlgebra: AbstractTriangular
2+
3+
# Matrix wrapper types that we know are square and are thus potentially invertible. For
4+
# these we can use simpler definitions for `/` and `\`.
5+
const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
6+
17
#####
28
##### `sum`
39
#####
@@ -69,3 +75,74 @@ end
6975
frule(::typeof(tr), x) = (tr(x), Rule(Δx -> tr(extern(Δx))))
7076

7177
rrule(::typeof(tr), x) = (tr(x), Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))
78+
79+
#####
80+
##### `*`
81+
#####
82+
83+
function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
84+
return A * B, (Rule(Ȳ ->* B'), Rule(Ȳ -> A' * Ȳ))
85+
end
86+
87+
#####
88+
##### `/`
89+
#####
90+
91+
function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real}
92+
Y = A / B
93+
S = T.name.wrapper
94+
∂A = Rule(Ȳ ->/ B')
95+
∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B')))
96+
return Y, (∂A, ∂B)
97+
end
98+
99+
function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
100+
Aᵀ, dA = rrule(adjoint, A)
101+
Bᵀ, dB = rrule(adjoint, B)
102+
Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ)
103+
C, dC = rrule(adjoint, Cᵀ)
104+
∂A = Rule(dAdAᵀdC)
105+
∂B = Rule(dAdBᵀdC)
106+
return C, (∂A, ∂B)
107+
end
108+
109+
#####
110+
##### `\`
111+
#####
112+
113+
function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real}
114+
Y = A \ B
115+
S = T.name.wrapper
116+
∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y'))
117+
∂B = Rule(Ȳ -> A' \ Ȳ)
118+
return Y, (∂A, ∂B)
119+
end
120+
121+
function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
122+
Y = A \ B
123+
∂A = Rule() do
124+
= A' \
125+
= -* Y'
126+
_add!(Ā, (B - A * Y) *' / A')
127+
_add!(Ā, A' \ Y * (Ȳ' -'A))
128+
129+
end
130+
∂B = Rule(Ȳ -> A' \ Ȳ)
131+
return Y, (∂A, ∂B)
132+
end
133+
134+
#####
135+
##### `norm`
136+
#####
137+
138+
function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2)
139+
y = norm(A, p)
140+
u = y^(1-p)
141+
∂A = Rule(ȳ ->.* u .* abs.(A).^p ./ A)
142+
∂p = Rule(ȳ ->* (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p)
143+
return y, (∂A, ∂p)
144+
end
145+
146+
function rrule(::typeof(norm), x::Real, p::Real=2)
147+
return norm(x, p), (Rule(ȳ ->* sign(x)), Rule(_ -> zero(x)))
148+
end

test/rules/linalg/dense.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,68 @@ end
7070
frule_test(tr, (randn(rng, N, N), randn(rng, N, N)))
7171
rrule_test(tr, randn(rng), (randn(rng, N, N), randn(rng, N, N)))
7272
end
73+
@testset "*" begin
74+
rng = MersenneTwister(123456)
75+
dims = [3,4,5]
76+
for n in dims, m in dims, p in dims
77+
n > 3 && n == m == p && continue # don't need to test square case multiple times
78+
A = randn(rng, m, n)
79+
B = randn(rng, n, p)
80+
= randn(rng, m, p)
81+
rrule_test(*, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, n, p)))
82+
end
83+
end
84+
@testset "$f" for f in [/, \]
85+
rng = MersenneTwister(42)
86+
for n in 3:5, m in 3:5
87+
A = randn(rng, m, n)
88+
B = randn(rng, m, n)
89+
= randn(rng, size(f(A, B)))
90+
rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n)))
91+
end
92+
# Vectors
93+
x = randn(rng, 10)
94+
y = randn(rng, 10)
95+
= randn(rng, size(f(x, y))...)
96+
rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10)))
97+
if f == (/)
98+
@testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
99+
RHS = T(randn(rng, T == Diagonal ? 10 : (10, 10)))
100+
Y = randn(rng, 5, 10)
101+
= randn(rng, size(f(Y, RHS))...)
102+
rrule_test(f, Ȳ, (Y, randn(rng, size(Y))), (RHS, randn(rng, size(RHS))))
103+
end
104+
else
105+
@testset "$T on LHS" for T in (Diagonal, UpperTriangular, LowerTriangular)
106+
LHS = T(randn(rng, T == Diagonal ? 10 : (10, 10)))
107+
y = randn(rng, 10)
108+
= randn(rng, size(f(LHS, y))...)
109+
rrule_test(f, ȳ, (LHS, randn(rng, size(LHS))), (y, randn(rng, 10)))
110+
Y = randn(rng, 10, 10)
111+
= randn(rng, 10, 10)
112+
rrule_test(f, Ȳ, (LHS, randn(rng, size(LHS))), (Y, randn(rng, size(Y))))
113+
end
114+
@testset "Matrix $f Vector" begin
115+
X = randn(rng, 10, 4)
116+
y = randn(rng, 10)
117+
= randn(rng, size(f(X, y))...)
118+
rrule_test(f, ȳ, (X, randn(rng, size(X))), (y, randn(rng, 10)))
119+
end
120+
@testset "Vector $f Matrix" begin
121+
x = randn(rng, 10)
122+
Y = randn(rng, 10, 4)
123+
= randn(rng, size(f(x, Y))...)
124+
rrule_test(f, ȳ, (x, randn(rng, size(x))), (Y, randn(rng, size(Y))))
125+
end
126+
end
127+
end
128+
@testset "norm" begin
129+
rng = MersenneTwister(3)
130+
for dims in [(), (5,), (3, 2), (7, 3, 2)]
131+
A = randn(rng, dims...)
132+
p = randn(rng)
133+
= randn(rng)
134+
rrule_test(norm, ȳ, (A, randn(rng, dims...)), (p, randn(rng)))
135+
end
136+
end
73137
end

0 commit comments

Comments
 (0)