Skip to content

Commit d052fc6

Browse files
authored
One and oneunit for triangular matrices (#43576)
1 parent 27eb721 commit d052fc6

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

stdlib/LinearAlgebra/src/special.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ one(D::Diagonal) = Diagonal(one.(D.diag))
315315
one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo)
316316
one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T))))
317317
one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))))
318+
for t in (:LowerTriangular, :UnitLowerTriangular, :UpperTriangular, :UnitUpperTriangular)
319+
@eval one(A::$t) = $t(one(parent(A)))
320+
@eval oneunit(A::$t) = $t(oneunit(parent(A)))
321+
end
318322

319323
zero(D::Diagonal) = Diagonal(zero.(D.diag))
320324
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,4 +800,28 @@ let A = [0.9999999999999998 4.649058915617843e-16 -1.3149405273715513e-16 9.9959
800800
B = [0.09648289218436859 0.023497875751503007 0.0 0.0; 0.023497875751503007 0.045787575150300804 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
801801
@test sqrt(A*B*A')^2 A*B*A'
802802
end
803+
804+
@testset "one and oneunit for triangular" begin
805+
m = rand(4,4)
806+
function test_one_oneunit_triangular(a)
807+
b = Matrix(a)
808+
@test (@inferred a^1) == b^1
809+
@test (@inferred a^-1) == b^-1
810+
@test one(a) == one(b)
811+
@test one(a)*a == a
812+
@test a*one(a) == a
813+
@test oneunit(a) == oneunit(b)
814+
@test oneunit(a) isa typeof(a)
815+
end
816+
for T in [UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular]
817+
a = T(m)
818+
test_one_oneunit_triangular(a)
819+
end
820+
# more complicated examples
821+
b = UpperTriangular(LowerTriangular(m))
822+
test_one_oneunit_triangular(b)
823+
c = UpperTriangular(Diagonal(rand(2)))
824+
test_one_oneunit_triangular(c)
825+
end
826+
803827
end # module TestTriangular

0 commit comments

Comments
 (0)