Skip to content

Commit 4e03986

Browse files
committed
Fix zero elements for block-matrix kron involving Diagonal (#55941)
1 parent 66b620f commit 4e03986

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -634,16 +634,33 @@ for Tri in (:UpperTriangular, :LowerTriangular)
634634
end
635635

636636
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
637-
valA = A.diag; nA = length(valA)
638-
valB = B.diag; nB = length(valB)
637+
valA = A.diag; mA, nA = size(A)
638+
valB = B.diag; mB, nB = size(B)
639639
nC = checksquare(C)
640640
@boundscheck nC == nA*nB ||
641641
throw(DimensionMismatch(lazy"expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
642-
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
642+
zerofilled = false
643+
if !(isempty(A) || isempty(B))
644+
z = A[1,1] * B[1,1]
645+
if haszero(typeof(z))
646+
# in this case, the zero is unique
647+
fill!(C, zero(z))
648+
zerofilled = true
649+
end
650+
end
643651
@inbounds for i = 1:nA, j = 1:nB
644652
idx = (i-1)*nB+j
645653
C[idx, idx] = valA[i] * valB[j]
646654
end
655+
if !zerofilled
656+
for j in 1:nA, i in 1:mA
657+
Δrow, Δcol = (i-1)*mB, (j-1)*nB
658+
for k in 1:nB, l in 1:mB
659+
i == j && k == l && continue
660+
C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
661+
end
662+
end
663+
end
647664
return C
648665
end
649666

@@ -670,7 +687,15 @@ end
670687
(mC, nC) = size(C)
671688
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
672689
throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
673-
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
690+
zerofilled = false
691+
if !(isempty(A) || isempty(B))
692+
z = A[1,1] * B[1,1]
693+
if haszero(typeof(z))
694+
# in this case, the zero is unique
695+
fill!(C, zero(z))
696+
zerofilled = true
697+
end
698+
end
674699
m = 1
675700
@inbounds for j = 1:nA
676701
A_jj = A[j,j]
@@ -681,6 +706,18 @@ end
681706
end
682707
m += (nA - 1) * mB
683708
end
709+
if !zerofilled
710+
# populate the zero elements
711+
for i in 1:mA
712+
i == j && continue
713+
A_ij = A[i, j]
714+
Δrow, Δcol = (i-1)*mB, (j-1)*nB
715+
for k in 1:nB, l in 1:nA
716+
B_lk = B[l, k]
717+
C[Δrow + l, Δcol + k] = A_ij * B_lk
718+
end
719+
end
720+
end
684721
m += mB
685722
end
686723
return C
@@ -693,17 +730,36 @@ end
693730
(mC, nC) = size(C)
694731
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
695732
throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
696-
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
733+
zerofilled = false
734+
if !(isempty(A) || isempty(B))
735+
z = A[1,1] * B[1,1]
736+
if haszero(typeof(z))
737+
# in this case, the zero is unique
738+
fill!(C, zero(z))
739+
zerofilled = true
740+
end
741+
end
697742
m = 1
698743
@inbounds for j = 1:nA
699744
for l = 1:mB
700745
Bll = B[l,l]
701-
for k = 1:mA
702-
C[m] = A[k,j] * Bll
746+
for i = 1:mA
747+
C[m] = A[i,j] * Bll
703748
m += nB
704749
end
705750
m += 1
706751
end
752+
if !zerofilled
753+
for i in 1:mA
754+
A_ij = A[i, j]
755+
Δrow, Δcol = (i-1)*mB, (j-1)*nB
756+
for k in 1:nB, l in 1:mB
757+
l == k && continue
758+
B_lk = B[l, k]
759+
C[Δrow + l, Δcol + k] = A_ij * B_lk
760+
end
761+
end
762+
end
707763
m -= nB
708764
end
709765
return C

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,4 +1323,14 @@ end
13231323
@test checkbounds(Bool, D, diagind(D, IndexCartesian()))
13241324
end
13251325

1326+
@testset "zeros in kron with block matrices" begin
1327+
D = Diagonal(1:2)
1328+
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
1329+
@test kron(D, B) == kron(Array(D), B)
1330+
@test kron(B, D) == kron(B, Array(D))
1331+
D2 = Diagonal([ones(2,2), ones(3,3)])
1332+
@test kron(D, D2) == Diagonal([diag(D2); 2diag(D2)])
1333+
@test kron(D2, D) == Diagonal([ones(2,2), fill(2.0,2,2), ones(3,3), fill(2.0,3,3)])
1334+
end
1335+
13261336
end # module TestDiagonal

0 commit comments

Comments
 (0)