Skip to content

Commit c741bd3

Browse files
stdlib: faster kronecker product between hermitian and symmetric matrices (#53186)
The kronecker product between complex hermitian matrices is again hermitian, so it can be computed much faster by only doing the upper (or lower) triangular. As @andreasnoack will surely notice, this only true for types where `conj(a*b) == conj(a)*conj(b)`, so I'm restricting the function to act only on real and complex numbers. In the symmetric case, however, no additional assumption is needed, so I'm letting it act on anything. Benchmarking showed that the code is roughly 2 times as fast as the vanilla kronecker product, as expected. The fastest case was always the UU case, and the slowest the LU case. The code I used is below ```julia using LinearAlgebra using BenchmarkTools using Quaternions randrmatrix(d, uplo = :U) = Hermitian(randn(Float64, d, d), uplo) randcmatrix(d, uplo = :U) = Hermitian(randn(ComplexF64, d, d), uplo) randsmatrix(d, uplo = :U) = Symmetric(randn(ComplexF64, d, d), uplo) randqmatrix(d, uplo = :U) = Symmetric(randn(QuaternionF64, d, d), uplo) dima = 69 dimb = 71 for randmatrix in [randrmatrix, randcmatrix, randsmatrix, randqmatrix] for auplo in [:U, :L] for buplo in [:U, :L] a = randmatrix(dima, auplo) b = randmatrix(dimb, buplo) c = kron(a,b) therm = @belapsed kron!($c, $a, $b) C = Matrix(c) A = Matrix(a) B = Matrix(b) told = @belapsed kron!($C, $A, $B) @show told/therm end end end ``` Weirdly enough, I got this expected speedup in one of my machines, but when running the benchmark in another I got roughly the same time. I guess that's a bug with `BechmarkTools`, because that's not consistent with the times I get running the functions individually, out of the loop. Another issue is that although I added a couple of tests, I couldn't get them to run. Perhaps someone here can tell me what's going on? I could run the tests from LinearAlgebra, it's just that editing the files made no difference to what was being run. I did get hundreds of errors from `triangular.jl`, but that's untouched by my code. --------- Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
1 parent 0f7674e commit c741bd3

File tree

5 files changed

+238
-2
lines changed

5 files changed

+238
-2
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,8 +491,8 @@ julia> reshape(kron(v,w), (length(w), length(v)))
491491
```
492492
"""
493493
function kron(A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S}
494-
R = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
495-
return kron!(R, A, B)
494+
C = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
495+
return kron!(C, A, B)
496496
end
497497
function kron(a::AbstractVector{T}, b::AbstractVector{S}) where {T,S}
498498
c = Vector{promote_op(*,T,S)}(undef, length(a)*length(b))

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,130 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
525525
end
526526
end
527527

528+
function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}}
529+
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
530+
C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
531+
return kron!(C, A, B)
532+
end
533+
534+
function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number}
535+
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
536+
C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
537+
return kron!(C, A, B)
538+
end
539+
540+
function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}})
541+
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
542+
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
543+
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
544+
end
545+
_hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo)
546+
return C
547+
end
548+
549+
function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number})
550+
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
551+
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
552+
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
553+
end
554+
_hermkron!(C.data, A.data, B.data, identity, identity, A.uplo, B.uplo)
555+
return C
556+
end
557+
558+
function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR}
559+
n_A = size(A, 1)
560+
n_B = size(B, 1)
561+
@inbounds if Auplo == 'U' && Buplo == 'U'
562+
for j = 1:n_A
563+
jnB = (j - 1) * n_B
564+
for i = 1:(j-1)
565+
Aij = A[i, j]
566+
inB = (i - 1) * n_B
567+
for l = 1:n_B
568+
for k = 1:(l-1)
569+
C[inB+k, jnB+l] = Aij * B[k, l]
570+
C[inB+l, jnB+k] = Aij * conj(B[k, l])
571+
end
572+
C[inB+l, jnB+l] = Aij * real(B[l, l])
573+
end
574+
end
575+
Ajj = real(A[j, j])
576+
for l = 1:n_B
577+
for k = 1:(l-1)
578+
C[jnB+k, jnB+l] = Ajj * B[k, l]
579+
end
580+
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
581+
end
582+
end
583+
elseif Auplo == 'U' && Buplo == 'L'
584+
for j = 1:n_A
585+
jnB = (j - 1) * n_B
586+
for i = 1:(j-1)
587+
Aij = A[i, j]
588+
inB = (i - 1) * n_B
589+
for l = 1:n_B
590+
C[inB+l, jnB+l] = Aij * real(B[l, l])
591+
for k = (l+1):n_B
592+
C[inB+l, jnB+k] = Aij * conj(B[k, l])
593+
C[inB+k, jnB+l] = Aij * B[k, l]
594+
end
595+
end
596+
end
597+
Ajj = real(A[j, j])
598+
for l = 1:n_B
599+
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
600+
for k = (l+1):n_B
601+
C[jnB+l, jnB+k] = Ajj * conj(B[k, l])
602+
end
603+
end
604+
end
605+
elseif Auplo == 'L' && Buplo == 'U'
606+
for j = 1:n_A
607+
jnB = (j - 1) * n_B
608+
Ajj = real(A[j, j])
609+
for l = 1:n_B
610+
for k = 1:(l-1)
611+
C[jnB+k, jnB+l] = Ajj * B[k, l]
612+
end
613+
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
614+
end
615+
for i = (j+1):n_A
616+
conjAij = conj(A[i, j])
617+
inB = (i - 1) * n_B
618+
for l = 1:n_B
619+
for k = 1:(l-1)
620+
C[jnB+k, inB+l] = conjAij * B[k, l]
621+
C[jnB+l, inB+k] = conjAij * conj(B[k, l])
622+
end
623+
C[jnB+l, inB+l] = conjAij * real(B[l, l])
624+
end
625+
end
626+
end
627+
else #if Auplo == 'L' && Buplo == 'L'
628+
for j = 1:n_A
629+
jnB = (j - 1) * n_B
630+
Ajj = real(A[j, j])
631+
for l = 1:n_B
632+
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
633+
for k = (l+1):n_B
634+
C[jnB+k, jnB+l] = Ajj * B[k, l]
635+
end
636+
end
637+
for i = (j+1):n_A
638+
Aij = A[i, j]
639+
inB = (i - 1) * n_B
640+
for l = 1:n_B
641+
C[inB+l, jnB+l] = Aij * real(B[l, l])
642+
for k = (l+1):n_B
643+
C[inB+k, jnB+l] = Aij * B[k, l]
644+
C[inB+l, jnB+k] = Aij * conj(B[k, l])
645+
end
646+
end
647+
end
648+
end
649+
end
650+
end
651+
528652
(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
529653
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))
530654

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,80 @@ for op in (:+, :-)
757757
end
758758
end
759759

760+
function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number}
761+
C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
762+
return kron!(C, A, B)
763+
end
764+
765+
function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number}
766+
C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
767+
return kron!(C, A, B)
768+
end
769+
770+
function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number})
771+
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
772+
_triukron!(C.data, A.data, B.data)
773+
return C
774+
end
775+
776+
function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number})
777+
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
778+
_trilkron!(C.data, A.data, B.data)
779+
return C
780+
end
781+
782+
function _triukron!(C, A, B)
783+
n_A = size(A, 1)
784+
n_B = size(B, 1)
785+
@inbounds for j = 1:n_A
786+
jnB = (j - 1) * n_B
787+
for i = 1:(j-1)
788+
Aij = A[i, j]
789+
inB = (i - 1) * n_B
790+
for l = 1:n_B
791+
for k = 1:l
792+
C[inB+k, jnB+l] = Aij * B[k, l]
793+
end
794+
for k = 1:(l-1)
795+
C[inB+l, jnB+k] = zero(eltype(C))
796+
end
797+
end
798+
end
799+
Ajj = A[j, j]
800+
for l = 1:n_B
801+
for k = 1:l
802+
C[jnB+k, jnB+l] = Ajj * B[k, l]
803+
end
804+
end
805+
end
806+
end
807+
808+
function _trilkron!(C, A, B)
809+
n_A = size(A, 1)
810+
n_B = size(B, 1)
811+
@inbounds for j = 1:n_A
812+
jnB = (j - 1) * n_B
813+
Ajj = A[j, j]
814+
for l = 1:n_B
815+
for k = l:n_B
816+
C[jnB+k, jnB+l] = Ajj * B[k, l]
817+
end
818+
end
819+
for i = (j+1):n_A
820+
Aij = A[i, j]
821+
inB = (i - 1) * n_B
822+
for l = 1:n_B
823+
for k = l:n_B
824+
C[inB+k, jnB+l] = Aij * B[k, l]
825+
end
826+
for k = (l+1):n_B
827+
C[inB+l, jnB+k] = zero(eltype(C))
828+
end
829+
end
830+
end
831+
end
832+
end
833+
760834
######################
761835
# BlasFloat routines #
762836
######################

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,28 @@ end
467467
@test dot(symblockml, symblockml) dot(msymblockml, msymblockml)
468468
end
469469
end
470+
471+
@testset "kronecker product of symmetric and Hermitian matrices" begin
472+
for mtype in (Symmetric, Hermitian)
473+
symau = mtype(a, :U)
474+
symal = mtype(a, :L)
475+
msymau = Matrix(symau)
476+
msymal = Matrix(symal)
477+
for eltyc in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int)
478+
creal = randn(n, n)/2
479+
cimag = randn(n, n)/2
480+
c = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(creal, cimag) : creal)
481+
symcu = mtype(c, :U)
482+
symcl = mtype(c, :L)
483+
msymcu = Matrix(symcu)
484+
msymcl = Matrix(symcl)
485+
@test kron(symau, symcu) kron(msymau, msymcu)
486+
@test kron(symau, symcl) kron(msymau, msymcl)
487+
@test kron(symal, symcu) kron(msymal, msymcu)
488+
@test kron(symal, symcl) kron(msymal, msymcl)
489+
end
490+
end
491+
end
470492
end
471493
end
472494

@@ -487,6 +509,7 @@ end
487509
@test S - S == MS - MS
488510
@test S*2 == 2*S == 2*MS
489511
@test S/2 == MS/2
512+
@test kron(S,S) == kron(MS,MS)
490513
end
491514
@testset "mixed uplo" begin
492515
Mu = Matrix{Complex{BigFloat}}(undef,2,2)
@@ -502,6 +525,8 @@ end
502525
MSl = Matrix(Sl)
503526
@test Su + Sl == Sl + Su == MSu + MSl
504527
@test Su - Sl == -(Sl - Su) == MSu - MSl
528+
@test kron(Su,Sl) == kron(MSu,MSl)
529+
@test kron(Sl,Su) == kron(MSl,MSu)
505530
end
506531
end
507532
end
@@ -517,6 +542,16 @@ end
517542
@test dot(A, B) dot(Symmetric(A), Symmetric(B))
518543
end
519544

545+
# let's make sure the analogous bug will not show up with kronecker products
546+
@testset "kron Hermitian quaternion #52318" begin
547+
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
548+
@test A == Hermitian(A) && B == Hermitian(B)
549+
@test kron(A, B) kron(Hermitian(A), Hermitian(B))
550+
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
551+
@test A == Symmetric(A) && B == Symmetric(B)
552+
@test kron(A, B) kron(Symmetric(A), Symmetric(B))
553+
end
554+
520555
#Issue #7647: test xsyevr, xheevr, xstevr drivers.
521556
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
522557
(Symmetric(diagm(0 => 1.0:3.0)),

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ debug && println("Test basic type functionality")
359359
# Binary operations
360360
@test A1 + A2 == M1 + M2
361361
@test A1 - A2 == M1 - M2
362+
@test kron(A1,A2) == kron(M1,M2)
362363

363364
# Triangular-Triangular multiplication and division
364365
@test A1*A2 M1*M2
@@ -1014,6 +1015,7 @@ end
10141015
@test 2\L == 2\B
10151016
@test real(L) == real(B)
10161017
@test imag(L) == imag(B)
1018+
@test kron(L,L) == kron(B,B)
10171019
@test transpose!(MT(copy(A))) == transpose(L) broken=!(A isa Matrix)
10181020
@test adjoint!(MT(copy(A))) == adjoint(L) broken=!(A isa Matrix)
10191021
end
@@ -1035,6 +1037,7 @@ end
10351037
@test 2\U == 2\B
10361038
@test real(U) == real(B)
10371039
@test imag(U) == imag(B)
1040+
@test kron(U,U) == kron(B,B)
10381041
@test transpose!(MT(copy(A))) == transpose(U) broken=!(A isa Matrix)
10391042
@test adjoint!(MT(copy(A))) == adjoint(U) broken=!(A isa Matrix)
10401043
end

0 commit comments

Comments
 (0)