Skip to content

Commit dc4adb0

Browse files
authored
Merge pull request #51 from JuliaDiff/aa/restructure
Minor code movement (NFC) and port some matrix types from Nabla
2 parents 0e551fb + 76ea4b8 commit dc4adb0

File tree

11 files changed

+102
-46
lines changed

11 files changed

+102
-46
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ julia = "^1.0"
1515

1616
[extras]
1717
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1920

2021
[targets]
21-
test = ["Test", "FDM"]
22+
test = ["FDM", "Random", "Test"]

src/ChainRules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ include("rules.jl")
1313
include("rules/base.jl")
1414
include("rules/array.jl")
1515
include("rules/broadcast.jl")
16+
include("rules/linalg/utils.jl")
17+
include("rules/linalg/blas.jl")
1618
include("rules/linalg/dense.jl")
17-
include("rules/linalg/diagonal.jl")
18-
include("rules/linalg/symmetric.jl")
19+
include("rules/linalg/structured.jl")
1920
include("rules/linalg/factorization.jl")
20-
include("rules/blas.jl")
2121
include("rules/nanmath.jl")
2222
include("rules/specialfunctions.jl")
2323

File renamed without changes.

src/rules/linalg/diagonal.jl

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/rules/linalg/factorization.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -59,33 +59,6 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
5959
return
6060
end
6161

62-
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
63-
k = size(X, 1)
64-
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
65-
if i == j
66-
X[i,i] = zero(T)
67-
else
68-
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
69-
end
70-
end
71-
X
72-
end
73-
74-
function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
75-
n, m = size(X)
76-
@inbounds for j = 1:m, i = 1:n
77-
X[i,j] = (i == j) - X[i,j]
78-
end
79-
X
80-
end
81-
82-
function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
83-
@inbounds for i = eachindex(X, Y)
84-
X[i] += Y[i]
85-
end
86-
X
87-
end
88-
8962
#####
9063
##### `cholesky`
9164
#####

src/rules/linalg/structured.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Structured matrices
2+
3+
#####
4+
##### `Diagonal`
5+
#####
6+
7+
rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)
8+
9+
rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)
10+
11+
#####
12+
##### `Symmetric`
13+
#####
14+
15+
rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)
16+
17+
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
18+
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ
19+
20+
#####
21+
##### `Adjoint`
22+
#####
23+
24+
# TODO: Deal with complex-valued arrays as well
25+
rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), Rule(adjoint)
26+
rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), Rule(vecadjoint)
27+
28+
rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), Rule(adjoint)
29+
rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), Rule(vecadjoint)
30+
31+
#####
32+
##### `Transpose`
33+
#####
34+
35+
rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), Rule(transpose)
36+
rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), Rule(vectranspose)
37+
38+
rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), Rule(transpose)
39+
rrule(::typeof(transpose), A::AbstractVector) = transpose(A), Rule(vectranspose)
40+
41+
#####
42+
##### Triangular matrices
43+
#####
44+
45+
rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix)
46+
47+
rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix)

src/rules/linalg/symmetric.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/rules/linalg/utils.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Some utility functions for optimizing linear algebra operations that aren't specific
2+
# to any particular rule definition
3+
4+
# F .* (X - X'), overwrites X
5+
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
6+
k = size(X, 1)
7+
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
8+
if i == j
9+
X[i,i] = zero(T)
10+
else
11+
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
12+
end
13+
end
14+
X
15+
end
16+
17+
# I - X, overwrites X
18+
function _eyesubx!(X::AbstractMatrix)
19+
n, m = size(X)
20+
@inbounds for j = 1:m, i = 1:n
21+
X[i,j] = (i == j) - X[i,j]
22+
end
23+
X
24+
end
25+
26+
# X + Y, overwrites X
27+
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
28+
@inbounds for i = eachindex(X, Y)
29+
X[i] += Y[i]
30+
end
31+
X
32+
end

test/rules/linalg/diagonal.jl renamed to test/rules/linalg/structured.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "diagonal" begin
1+
@testset "Structured Matrices" begin
22
@testset "Diagonal" begin
33
rng, N = MersenneTwister(123456), 3
44
rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
@@ -14,4 +14,20 @@
1414
rrule_test(diag, randn(rng, N), (randn(rng, N, N), Diagonal(randn(rng, N))))
1515
rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), Diagonal(randn(rng, N))))
1616
end
17+
@testset "Symmetric" begin
18+
rng, N = MersenneTwister(123456), 3
19+
rrule_test(Symmetric, randn(rng, N, N), (randn(rng, N, N), randn(rng, N, N)))
20+
end
21+
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
22+
rng = MersenneTwister(32)
23+
n = 5
24+
m = 3
25+
rrule_test(f, randn(rng, m, n), (randn(rng, n, m), randn(rng, n, m)))
26+
rrule_test(f, randn(rng, 1, n), (randn(rng, n), randn(rng, n)))
27+
end
28+
@testset "$T" for T in (UpperTriangular, LowerTriangular)
29+
rng = MersenneTwister(33)
30+
n = 5
31+
rrule_test(T, T(randn(rng, n, n)), (randn(rng, n, n), randn(rng, n, n)))
32+
end
1733
end

test/rules/linalg/symmetric.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)