Skip to content

Commit 675feca

Browse files
Add Tridiagonal construction rule
1 parent 87f4996 commit 675feca

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
268268
end
269269
return y, logdet_pullback
270270
end
271+
272+
#####
273+
##### Tridiagonal
274+
#####
275+
276+
function rrule(::Type{Tridiagonal}, dl, d, du)
277+
y = Tridiagonal(dl, d, du)
278+
@views function ∇Tridiagonal(∂y)
279+
return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y),
280+
diag(∂y[1:(end - 1), 2:end]))
281+
end
282+
return y, ∇Tridiagonal
283+
end

0 commit comments

Comments
 (0)