Skip to content

Commit ebb7b2a

Browse files
aravindh-krishnamoorthyDilumAluthgestevengj
authored
Implement cbrt(A::AbstractMatrix{<:Real}) (#50661)
Co-authored-by: Dilum Aluthge <dilum@aluthge.com> Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
1 parent ae44b98 commit ebb7b2a

File tree

7 files changed

+202
-4
lines changed

7 files changed

+202
-4
lines changed

docs/src/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ as well as whether hooks to various optimized methods for them in LAPACK are ava
184184

185185
| Matrix type | `+` | `-` | `*` | `\` | Other functions with optimized methods |
186186
|:----------------------------- |:--- |:--- |:--- |:--- |:----------------------------------------------------------- |
187-
| [`Symmetric`](@ref) | | | | MV | [`inv`](@ref), [`sqrt`](@ref), [`exp`](@ref) |
188-
| [`Hermitian`](@ref) | | | | MV | [`inv`](@ref), [`sqrt`](@ref), [`exp`](@ref) |
187+
| [`Symmetric`](@ref) | | | | MV | [`inv`](@ref), [`sqrt`](@ref), [`cbrt`](@ref), [`exp`](@ref) |
188+
| [`Hermitian`](@ref) | | | | MV | [`inv`](@ref), [`sqrt`](@ref), [`cbrt`](@ref), [`exp`](@ref) |
189189
| [`UpperTriangular`](@ref) | | | MV | MV | [`inv`](@ref), [`det`](@ref), [`logdet`](@ref) |
190190
| [`UnitUpperTriangular`](@ref) | | | MV | MV | [`inv`](@ref), [`det`](@ref), [`logdet`](@ref) |
191191
| [`LowerTriangular`](@ref) | | | MV | MV | [`inv`](@ref), [`det`](@ref), [`logdet`](@ref) |
@@ -516,6 +516,7 @@ Base.:^(::AbstractMatrix, ::Number)
516516
Base.:^(::Number, ::AbstractMatrix)
517517
LinearAlgebra.log(::StridedMatrix)
518518
LinearAlgebra.sqrt(::StridedMatrix)
519+
LinearAlgebra.cbrt(::AbstractMatrix{<:Real})
519520
LinearAlgebra.cos(::StridedMatrix{<:Real})
520521
LinearAlgebra.sin(::StridedMatrix{<:Real})
521522
LinearAlgebra.sincos(::StridedMatrix{<:Real})

src/LinearAlgebra.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ module LinearAlgebra
99

1010
import Base: \, /, *, ^, +, -, ==
1111
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
12-
asin, asinh, atan, atanh, axes, big, broadcast, ceil, cis, collect, conj, convert, copy,
13-
copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
12+
asin, asinh, atan, atanh, axes, big, broadcast, cbrt, ceil, cis, collect, conj, convert,
13+
copy, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
1414
getindex, hcat, getproperty, imag, inv, isapprox, isequal, isone, iszero, IndexStyle,
1515
kron, kron!, length, log, map, ndims, one, oneunit, parent, permutedims,
1616
power_by_squaring, promote_rule, real, sec, sech, setindex!, show, similar, sin,

src/dense.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,54 @@ end
907907
sqrt(A::AdjointAbsMat) = adjoint(sqrt(parent(A)))
908908
sqrt(A::TransposeAbsMat) = transpose(sqrt(parent(A)))
909909

910+
"""
911+
cbrt(A::AbstractMatrix{<:Real})
912+
913+
Computes the real-valued cube root of a real-valued matrix `A`. If `T = cbrt(A)`, then
914+
we have `T*T*T ≈ A`, see example given below.
915+
916+
If `A` is symmetric, i.e., of type `HermOrSym{<:Real}`, then ([`eigen`](@ref)) is used to
917+
find the cube root. Otherwise, a specialized version of the p-th root algorithm [^S03] is
918+
utilized, which exploits the real-valued Schur decomposition ([`schur`](@ref))
919+
to compute the cube root.
920+
921+
[^S03]:
922+
923+
Matthew I. Smith, "A Schur Algorithm for Computing Matrix pth Roots",
924+
SIAM Journal on Matrix Analysis and Applications, vol. 24, 2003, pp. 971–989.
925+
[doi:10.1137/S0895479801392697](https://doi.org/10.1137/s0895479801392697)
926+
927+
# Examples
928+
```jldoctest
929+
julia> A = [0.927524 -0.15857; -1.3677 -1.01172]
930+
2×2 Matrix{Float64}:
931+
0.927524 -0.15857
932+
-1.3677 -1.01172
933+
934+
julia> T = cbrt(A)
935+
2×2 Matrix{Float64}:
936+
0.910077 -0.151019
937+
-1.30257 -0.936818
938+
939+
julia> T*T*T ≈ A
940+
true
941+
```
942+
"""
943+
function cbrt(A::AbstractMatrix{<:Real})
944+
if checksquare(A) == 0
945+
return copy(A)
946+
elseif issymmetric(A)
947+
return cbrt(Symmetric(A, :U))
948+
else
949+
S = schur(A)
950+
return S.Z * _cbrt_quasi_triu!(S.T) * S.Z'
951+
end
952+
end
953+
954+
# Cube roots of adjoint and transpose matrices
955+
cbrt(A::AdjointAbsMat) = adjoint(cbrt(parent(A)))
956+
cbrt(A::TransposeAbsMat) = transpose(cbrt(parent(A)))
957+
910958
function inv(A::StridedMatrix{T}) where T
911959
checksquare(A)
912960
if istriu(A)

src/diagonal.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,9 @@ for f in (:exp, :cis, :log, :sqrt,
736736
@eval $f(D::Diagonal) = Diagonal($f.(D.diag))
737737
end
738738

739+
# Cube root of a real-valued diagonal matrix
740+
cbrt(A::Diagonal{<:Real}) = Diagonal(cbrt.(A.diag))
741+
739742
function inv(D::Diagonal{T}) where T
740743
Di = similar(D.diag, typeof(inv(oneunit(T))))
741744
for i = 1:length(D.diag)

src/symmetric.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,13 @@ for func in (:log, :sqrt)
812812
end
813813
end
814814

815+
# Cube root of a real-valued symmetric matrix
816+
function cbrt(A::HermOrSym{<:Real})
817+
F = eigen(A)
818+
A = F.vectors * Diagonal(cbrt.(F.values)) * F.vectors'
819+
return A
820+
end
821+
815822
"""
816823
hermitianpart(A, uplo=:U) -> Hermitian
817824

src/triangular.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2537,3 +2537,94 @@ for (tritype, comptritype) in ((:LowerTriangular, :UpperTriangular),
25372537
@eval /(u::TransposeAbsVec, A::$tritype{<:Any,<:Adjoint}) = transpose($comptritype(conj(parent(parent(A)))) \ u.parent)
25382538
@eval /(u::TransposeAbsVec, A::$tritype{<:Any,<:Transpose}) = transpose(transpose(A) \ u.parent)
25392539
end
2540+
2541+
# Cube root of a 2x2 real-valued matrix with complex conjugate eigenvalues and equal diagonal values.
2542+
# Reference [1]: Smith, M. I. (2003). A Schur Algorithm for Computing Matrix pth Roots.
2543+
# SIAM Journal on Matrix Analysis and Applications (Vol. 24, Issue 4, pp. 971–989).
2544+
# https://doi.org/10.1137/s0895479801392697
2545+
function _cbrt_2x2!(A::AbstractMatrix{T}) where {T<:Real}
2546+
@assert checksquare(A) == 2
2547+
@inbounds begin
2548+
(A[1,1] == A[2,2]) || throw(ArgumentError("_cbrt_2x2!: Matrix A must have equal diagonal values."))
2549+
(A[1,2]*A[2,1] < 0) || throw(ArgumentError("_cbrt_2x2!: Matrix A must have complex conjugate eigenvalues."))
2550+
μ = sqrt(-A[1,2]*A[2,1])
2551+
r = cbrt(hypot(A[1,1], μ))
2552+
θ = atan(μ, A[1,1])
2553+
s, c = sincos/3)
2554+
α, β′ = r*c, r*s/µ
2555+
A[1,1] = α
2556+
A[2,2] = α
2557+
A[1,2] = β′*A[1,2]
2558+
A[2,1] = β′*A[2,1]
2559+
end
2560+
return A
2561+
end
2562+
2563+
# Cube root of a quasi upper triangular matrix (output of Schur decomposition)
2564+
# Reference [1]: Smith, M. I. (2003). A Schur Algorithm for Computing Matrix pth Roots.
2565+
# SIAM Journal on Matrix Analysis and Applications (Vol. 24, Issue 4, pp. 971–989).
2566+
# https://doi.org/10.1137/s0895479801392697
2567+
@views function _cbrt_quasi_triu!(A::AbstractMatrix{T}) where {T<:Real}
2568+
m, n = size(A)
2569+
(m == n) || throw(ArgumentError("_cbrt_quasi_triu!: Matrix A must be square."))
2570+
# Cube roots of 1x1 and 2x2 diagonal blocks
2571+
i = 1
2572+
sizes = ones(Int,n)
2573+
S = zeros(T,2,n)
2574+
while i < n
2575+
if !iszero(A[i+1,i])
2576+
_cbrt_2x2!(A[i:i+1,i:i+1])
2577+
mul!(S[1:2,i:i+1], A[i:i+1,i:i+1], A[i:i+1,i:i+1])
2578+
sizes[i] = 2
2579+
sizes[i+1] = 0
2580+
i += 2
2581+
else
2582+
A[i,i] = cbrt(A[i,i])
2583+
S[1,i] = A[i,i]*A[i,i]
2584+
i += 1
2585+
end
2586+
end
2587+
if i == n
2588+
A[n,n] = cbrt(A[n,n])
2589+
S[1,n] = A[n,n]*A[n,n]
2590+
end
2591+
# Algorithm 4.3 in Reference [1]
2592+
Δ = I(4)
2593+
M_L₀ = zeros(T,4,4)
2594+
M_L₁ = zeros(T,4,4)
2595+
M_Bᵢⱼ⁽⁰⁾ = zeros(T,2,2)
2596+
M_Bᵢⱼ⁽¹⁾ = zeros(T,2,2)
2597+
for k = 1:n-1
2598+
for i = 1:n-k
2599+
if sizes[i] == 0 || sizes[i+k] == 0 continue end
2600+
k₁, k₂ = i+1+(sizes[i+1]==0), i+k-1
2601+
i₁, i₂, j₁, j₂, s₁, s₂ = i, i+sizes[i]-1, i+k, i+k+sizes[i+k]-1, sizes[i], sizes[i+k]
2602+
L₀ = M_L₀[1:s₁*s₂,1:s₁*s₂]
2603+
L₁ = M_L₁[1:s₁*s₂,1:s₁*s₂]
2604+
Bᵢⱼ⁽⁰⁾ = M_Bᵢⱼ⁽⁰⁾[1:s₁, 1:s₂]
2605+
Bᵢⱼ⁽¹⁾ = M_Bᵢⱼ⁽¹⁾[1:s₁, 1:s₂]
2606+
# Compute Bᵢⱼ⁽⁰⁾ and Bᵢⱼ⁽¹⁾
2607+
mul!(Bᵢⱼ⁽⁰⁾, A[i₁:i₂,k₁:k₂], A[k₁:k₂,j₁:j₂])
2608+
# Retreive Rᵢ,ᵢ₊ₖ as A[i+k,i]'
2609+
mul!(Bᵢⱼ⁽¹⁾, A[i₁:i₂,k₁:k₂], A[j₁:j₂,k₁:k₂]')
2610+
# Solve Uᵢ,ᵢ₊ₖ using Reference [1, (4.10)]
2611+
kron!(L₀, Δ[1:s₂,1:s₂], S[1:s₁,i₁:i₂])
2612+
L₀ .+= kron!(L₁, A[j₁:j₂,j₁:j₂]', A[i₁:i₂,i₁:i₂])
2613+
L₀ .+= kron!(L₁, S[1:s₂,j₁:j₂]', Δ[1:s₁,1:s₁])
2614+
mul!(A[i₁:i₂,j₁:j₂], A[i₁:i₂,i₁:i₂], Bᵢⱼ⁽⁰⁾, -1.0, 1.0)
2615+
A[i₁:i₂,j₁:j₂] .-= Bᵢⱼ⁽¹⁾
2616+
ldiv!(lu!(L₀), A[i₁:i₂,j₁:j₂][:])
2617+
# Compute and store Rᵢ,ᵢ₊ₖ' in A[i+k,i]
2618+
mul!(Bᵢⱼ⁽⁰⁾, A[i₁:i₂,i₁:i₂], A[i₁:i₂,j₁:j₂], 1.0, 1.0)
2619+
mul!(Bᵢⱼ⁽⁰⁾, A[i₁:i₂,j₁:j₂], A[j₁:j₂,j₁:j₂], 1.0, 1.0)
2620+
A[j₁:j₂,i₁:i₂] .= Bᵢⱼ⁽⁰⁾'
2621+
end
2622+
end
2623+
# Make quasi triangular
2624+
for j=1:n for i=j+1+(sizes[j]==2):n A[i,j] = 0 end end
2625+
return A
2626+
end
2627+
2628+
# Cube roots of real-valued triangular matrices
2629+
cbrt(A::UpperTriangular{T}) where {T<:Real} = UpperTriangular(_cbrt_quasi_triu!(Matrix{T}(A)))
2630+
cbrt(A::LowerTriangular{T}) where {T<:Real} = LowerTriangular(_cbrt_quasi_triu!(Matrix{T}(A'))')

test/dense.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,4 +1232,52 @@ Base.:+(x::TypeWithZero, ::TypeWithoutZero) = x
12321232
@test diagm(0 => [TypeWithoutZero()]) isa Matrix{TypeWithZero}
12331233
end
12341234

1235+
@testset "cbrt(A::AbstractMatrix{T})" begin
1236+
N = 10
1237+
1238+
# Non-square
1239+
A = randn(N,N+2)
1240+
@test_throws DimensionMismatch cbrt(A)
1241+
1242+
# Real valued diagonal
1243+
D = Diagonal(randn(N))
1244+
T = cbrt(D)
1245+
@test T*T*T D
1246+
@test eltype(D) == eltype(T)
1247+
# Real valued triangular
1248+
U = UpperTriangular(randn(N,N))
1249+
T = cbrt(U)
1250+
@test T*T*T U
1251+
@test eltype(U) == eltype(T)
1252+
L = LowerTriangular(randn(N,N))
1253+
T = cbrt(L)
1254+
@test T*T*T L
1255+
@test eltype(L) == eltype(T)
1256+
# Real valued symmetric
1257+
S = (A -> (A+A')/2)(randn(N,N))
1258+
T = cbrt(Symmetric(S,:U))
1259+
@test T*T*T S
1260+
@test eltype(S) == eltype(T)
1261+
# Real valued symmetric
1262+
S = (A -> (A+A')/2)(randn(N,N))
1263+
T = cbrt(Symmetric(S,:L))
1264+
@test T*T*T S
1265+
@test eltype(S) == eltype(T)
1266+
# Real valued Hermitian
1267+
S = (A -> (A+A')/2)(randn(N,N))
1268+
T = cbrt(Hermitian(S,:U))
1269+
@test T*T*T S
1270+
@test eltype(S) == eltype(T)
1271+
# Real valued Hermitian
1272+
S = (A -> (A+A')/2)(randn(N,N))
1273+
T = cbrt(Hermitian(S,:L))
1274+
@test T*T*T S
1275+
@test eltype(S) == eltype(T)
1276+
# Real valued arbitrary
1277+
A = randn(N,N)
1278+
T = cbrt(A)
1279+
@test T*T*T A
1280+
@test eltype(A) == eltype(T)
1281+
end
1282+
12351283
end # module TestDense

0 commit comments

Comments
 (0)