@@ -634,16 +634,33 @@ for Tri in (:UpperTriangular, :LowerTriangular)
634
634
end
635
635
636
636
@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
637
- valA = A. diag; nA = length (valA )
638
- valB = B. diag; nB = length (valB )
637
+ valA = A. diag; mA, nA = size (A )
638
+ valB = B. diag; mB, nB = size (B )
639
639
nC = checksquare (C)
640
640
@boundscheck nC == nA* nB ||
641
641
throw (DimensionMismatch (lazy " expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)" ))
642
- isempty (A) || isempty (B) || fill! (C, zero (A[1 ,1 ] * B[1 ,1 ]))
642
+ zerofilled = false
643
+ if ! (isempty (A) || isempty (B))
644
+ z = A[1 ,1 ] * B[1 ,1 ]
645
+ if haszero (typeof (z))
646
+ # in this case, the zero is unique
647
+ fill! (C, zero (z))
648
+ zerofilled = true
649
+ end
650
+ end
643
651
@inbounds for i = 1 : nA, j = 1 : nB
644
652
idx = (i- 1 )* nB+ j
645
653
C[idx, idx] = valA[i] * valB[j]
646
654
end
655
+ if ! zerofilled
656
+ for j in 1 : nA, i in 1 : mA
657
+ Δrow, Δcol = (i- 1 )* mB, (j- 1 )* nB
658
+ for k in 1 : nB, l in 1 : mB
659
+ i == j && k == l && continue
660
+ C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
661
+ end
662
+ end
663
+ end
647
664
return C
648
665
end
649
666
670
687
(mC, nC) = size (C)
671
688
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
672
689
throw (DimensionMismatch (lazy " expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)" ))
673
- isempty (A) || isempty (B) || fill! (C, zero (A[1 ,1 ] * B[1 ,1 ]))
690
+ zerofilled = false
691
+ if ! (isempty (A) || isempty (B))
692
+ z = A[1 ,1 ] * B[1 ,1 ]
693
+ if haszero (typeof (z))
694
+ # in this case, the zero is unique
695
+ fill! (C, zero (z))
696
+ zerofilled = true
697
+ end
698
+ end
674
699
m = 1
675
700
@inbounds for j = 1 : nA
676
701
A_jj = A[j,j]
681
706
end
682
707
m += (nA - 1 ) * mB
683
708
end
709
+ if ! zerofilled
710
+ # populate the zero elements
711
+ for i in 1 : mA
712
+ i == j && continue
713
+ A_ij = A[i, j]
714
+ Δrow, Δcol = (i- 1 )* mB, (j- 1 )* nB
715
+ for k in 1 : nB, l in 1 : nA
716
+ B_lk = B[l, k]
717
+ C[Δrow + l, Δcol + k] = A_ij * B_lk
718
+ end
719
+ end
720
+ end
684
721
m += mB
685
722
end
686
723
return C
@@ -693,17 +730,36 @@ end
693
730
(mC, nC) = size (C)
694
731
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
695
732
throw (DimensionMismatch (lazy " expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)" ))
696
- isempty (A) || isempty (B) || fill! (C, zero (A[1 ,1 ] * B[1 ,1 ]))
733
+ zerofilled = false
734
+ if ! (isempty (A) || isempty (B))
735
+ z = A[1 ,1 ] * B[1 ,1 ]
736
+ if haszero (typeof (z))
737
+ # in this case, the zero is unique
738
+ fill! (C, zero (z))
739
+ zerofilled = true
740
+ end
741
+ end
697
742
m = 1
698
743
@inbounds for j = 1 : nA
699
744
for l = 1 : mB
700
745
Bll = B[l,l]
701
- for k = 1 : mA
702
- C[m] = A[k ,j] * Bll
746
+ for i = 1 : mA
747
+ C[m] = A[i ,j] * Bll
703
748
m += nB
704
749
end
705
750
m += 1
706
751
end
752
+ if ! zerofilled
753
+ for i in 1 : mA
754
+ A_ij = A[i, j]
755
+ Δrow, Δcol = (i- 1 )* mB, (j- 1 )* nB
756
+ for k in 1 : nB, l in 1 : mB
757
+ l == k && continue
758
+ B_lk = B[l, k]
759
+ C[Δrow + l, Δcol + k] = A_ij * B_lk
760
+ end
761
+ end
762
+ end
707
763
m -= nB
708
764
end
709
765
return C
0 commit comments