Skip to content

Commit b501fda

Browse files
committed
Support Tridiagonal in to_vec
1 parent 119bcd7 commit b501fda

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.12.31"
3+
version = "0.12.32"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/to_vec.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,26 @@ function to_vec(x::T) where {T<:LinearAlgebra.HermOrSym}
111111
return x_vec, HermOrSym_from_vec
112112
end
113113

114-
function to_vec(X::Diagonal)
115-
x_vec, back = to_vec(Matrix(X))
114+
function to_vec(x::Diagonal)
115+
x_vec, back = to_vec(Matrix(x))
116116
function Diagonal_from_vec(x_vec)
117117
return Diagonal(back(x_vec))
118118
end
119119
return x_vec, Diagonal_from_vec
120120
end
121121

122+
function to_vec(x::Tridiagonal)
123+
x_vec, back = to_vec((x.dl, x.d, x.du))
124+
# Other field (du2) of a Tridiagonal is not part of its value and is really a kind of cache
125+
function Tridiagonal_from_vec(x_vec)
126+
return Tridiagonal(back(x_vec)...)
127+
end
128+
return x_vec, Tridiagonal_from_vec
129+
end
130+
131+
132+
133+
122134
function to_vec(X::Transpose)
123135
x_vec, back = to_vec(Matrix(X))
124136
function Transpose_from_vec(x_vec)

test/to_vec.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ end
8888
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred=false)
8989
test_to_vec(UpperTriangular(randn(T, 13, 13)))
9090
test_to_vec(Diagonal(randn(T, 7)))
91+
test_to_vec(Tridiagonal(randn(T, 3), randn(T, 4), randn(T, 3)))
92+
9193
test_to_vec(DummyType(randn(T, 2, 9)))
9294
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false)
9395
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)

0 commit comments

Comments
 (0)