diff --git a/test/testhelpers/ImmutableArrays.jl b/test/testhelpers/ImmutableArrays.jl index 8f2d23be..61c78a6d 100644 --- a/test/testhelpers/ImmutableArrays.jl +++ b/test/testhelpers/ImmutableArrays.jl @@ -7,6 +7,8 @@ module ImmutableArrays +using LinearAlgebra + export ImmutableArray "An immutable wrapper type for arrays." @@ -28,4 +30,34 @@ AbstractArray{T,N}(A::ImmutableArray{S,N}) where {S,T,N} = ImmutableArray(Abstra Base.copy(A::ImmutableArray) = ImmutableArray(copy(A.data)) Base.zero(A::ImmutableArray) = ImmutableArray(zero(A.data)) +Base.:(-)(A::ImmutableArray) = ImmutableArray(-A.data) +Base.:(+)(A::ImmutableArray, B::ImmutableArray) = ImmutableArray(A.data + B.data) +Base.:(-)(A::ImmutableArray, B::ImmutableArray) = ImmutableArray(A.data - B.data) + +Base.:(*)(A::ImmutableArray, x::Number) = ImmutableArray(A.data * x) +Base.:(*)(x::Number, A::ImmutableArray) = ImmutableArray(x * A.data) + +Base.:(*)(A::ImmutableArray, B::ImmutableArray) = ImmutableArray(A.data * B.data) + +function LinearAlgebra.eigen(S::SymTridiagonal{T, <:ImmutableArray{T,1}}) where {T} + # Use the underlying data for the eigen computation + S2 = SymTridiagonal(diag(S), diag(S,1)) + eigvals, eigvecs = eigen(S2) + return Eigen(ImmutableArray(eigvals), ImmutableArray(eigvecs)) +end + +function LinearAlgebra.eigen(S::Symmetric{T, <:ImmutableArray{T,2}}) where {T<:Real} + # Use the underlying data for the eigen computation + S2 = Symmetric(parent(S).data) + eigvals, eigvecs = eigen(S2) + return Eigen(ImmutableArray(eigvals), ImmutableArray(eigvecs)) +end + +function LinearAlgebra.eigen(S::Hermitian{T, <:ImmutableArray{T,2}}) where {T<:Union{Real,Complex}} + # Use the underlying data for the eigen computation + S2 = Hermitian(parent(S).data) + eigvals, eigvecs = eigen(S2) + return Eigen(ImmutableArray(eigvals), ImmutableArray(eigvecs)) +end + end diff --git a/test/tridiag.jl b/test/tridiag.jl index 9a90e31b..97777a37 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1198,4 +1198,32 @@ end @test_throws BoundsError S[LinearAlgebra.BandIndex(0,size(S,1)+1)] end +@testset "special functions" begin + _dv = Float64[1,-2,3,-4] + _ev = Float64[1,2,3] + @testset "$(typeof(dv))" for (dv, ev) in ((_dv, _ev), ImmutableArray.((_dv, _ev))) + dl = -ev + T = Tridiagonal(dl, dv, ev) + MT = Matrix(T) + S = SymTridiagonal(dv, ev) + MS = Matrix(S) + + @testset for f in Any[sin, cos, tan, + asin, acos, atan, + sinh, cosh, tanh, + asinh, acosh, atanh, + exp, log, sqrt, cbrt, + ] + @test f(T) ≈ f(MT) + @test f(S) ≈ f(MS) + end + for (ST, MST) in ((S, MS), (T, MT)) + sT, cT = sincos(ST) + sMT, cMT = sincos(MST) + @test sT ≈ sMT + @test cT ≈ cMT + end + end +end + end # module TestTridiagonal