Skip to content

Commit 03b7d10

Browse files
committed
Port Adjoint, Transpose, and triangular from Nabla
1 parent 55a3923 commit 03b7d10

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

src/rules/linalg/structured.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,32 @@ rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_ba
1616

1717
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
1818
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ
19+
20+
#####
21+
##### `Adjoint`
22+
#####
23+
24+
# TODO: Deal with complex-valued arrays as well
25+
rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), Rule(adjoint)
26+
rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), Rule(vecadjoint)
27+
28+
rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), Rule(adjoint)
29+
rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), Rule(vecadjoint)
30+
31+
#####
32+
##### `Transpose`
33+
#####
34+
35+
rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), Rule(transpose)
36+
rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), Rule(vectranspose)
37+
38+
rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), Rule(transpose)
39+
rrule(::typeof(transpose), A::AbstractVector) = transpose(A), Rule(vectranspose)
40+
41+
#####
42+
##### Triangular matrices
43+
#####
44+
45+
rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix)
46+
47+
rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix)

test/rules/linalg/structured.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,16 @@
1818
rng, N = MersenneTwister(123456), 3
1919
rrule_test(Symmetric, randn(rng, N, N), (randn(rng, N, N), randn(rng, N, N)))
2020
end
21+
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
22+
rng = MersenneTwister(32)
23+
n = 5
24+
m = 3
25+
rrule_test(f, randn(rng, m, n), (randn(rng, n, m), randn(rng, n, m)))
26+
rrule_test(f, randn(rng, 1, n), (randn(rng, n), randn(rng, n)))
27+
end
28+
@testset "$T" for T in (UpperTriangular, LowerTriangular)
29+
rng = MersenneTwister(33)
30+
n = 5
31+
rrule_test(T, T(randn(rng, n, n)), (randn(rng, n, n), randn(rng, n, n)))
32+
end
2133
end

0 commit comments

Comments
 (0)