Skip to content

Commit 1425a11

Browse files
authored
kron for Diagonal uses kron for components (#40596)
1 parent 79920db commit 1425a11

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,7 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
493493
return C
494494
end
495495

496-
function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}
497-
valA = A.diag; nA = length(valA)
498-
valB = B.diag; nB = length(valB)
499-
valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB)
500-
C = Diagonal(valC)
501-
return @inbounds kron!(C, A, B)
502-
end
496+
kron(A::Diagonal{<:Number}, B::Diagonal{<:Number}) = Diagonal(kron(A.diag, B.diag))
503497

504498
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
505499
Base.require_one_based_indexing(B)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,23 @@ Random.seed!(1)
403403

404404
end
405405

406+
@testset "kron (issue #40595)" begin
407+
# custom array type to test that kron on Diagonal matrices preserves types of the parents if possible
408+
struct KronTestArray{T, N, AT} <: AbstractArray{T, N}
409+
data::AT
410+
end
411+
KronTestArray(data::AbstractArray) = KronTestArray{eltype(data), ndims(data), typeof(data)}(data)
412+
Base.size(A::KronTestArray) = size(A.data)
413+
LinearAlgebra.kron(A::KronTestArray, B::KronTestArray) = KronTestArray(kron(A.data, B.data))
414+
Base.getindex(K::KronTestArray{<:Any,N}, i::Vararg{Int,N}) where {N} = K.data[i...]
415+
416+
A = KronTestArray([1, 2, 3]);
417+
@test kron(A, A) isa KronTestArray
418+
Ad = Diagonal(A);
419+
@test kron(Ad, Ad).diag isa KronTestArray
420+
@test kron(Ad, Ad).diag == kron([1, 2, 3], [1, 2, 3])
421+
end
422+
406423
@testset "svdvals and eigvals (#11120/#11247)" begin
407424
D = Diagonal(Matrix{Float64}[randn(3,3), randn(2,2)])
408425
@test sort([svdvals(D)...;], rev = true) svdvals([D.diag[1] zeros(3,2); zeros(2,3) D.diag[2]])

0 commit comments

Comments
 (0)