Skip to content

Commit 78f00cb

Browse files
authored
Merge pull request #758 from JuliaDiff/ChrisRackauckas-patch-1
Add Tridiagonal construction rule
2 parents be9c221 + ce288b9 commit 78f00cb

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
267267
end
268268
return y, logdet_pullback
269269
end
270+
271+
#####
272+
##### Tridiagonal
273+
#####
274+
275+
function rrule(::Type{Tridiagonal}, dl, d, du)
276+
y = Tridiagonal(dl, d, du)
277+
function Tridiagonal_pullback(ȳ)
278+
∂y = unthunk(ȳ)
279+
return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
280+
end
281+
return y, Tridiagonal_pullback
282+
end

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,8 @@
161161
end
162162
end
163163
end
164+
165+
@testset "Tridiagonal" begin
166+
test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0])
167+
end
164168
end

0 commit comments

Comments
 (0)