Skip to content

Commit a3bcdba

Browse files
authored
Fix righthand side diagonal multiplication (#441)
1 parent f4e6ae3 commit a3bcdba

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

src/host/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ else
212212
m′, n′ = size(B, 1), size(B, 2)
213213
n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d"))
214214
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
215-
@. B' = dd * A'
215+
B .= A .* transpose(dd)
216216

217217
B
218218
end
@@ -228,7 +228,7 @@ else
228228
m′, n′ = size(B, 1), size(B, 2)
229229
n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d"))
230230
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
231-
@. B' = α * dd * A' + β * B'
231+
B .= α * A .* transpose(dd) + β * B
232232

233233
B
234234
end

test/testsuite/linalg.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -154,26 +154,28 @@
154154
end
155155

156156
@testset "mul! + Diagonal" begin
157-
n = 128
158-
d = AT(rand(Float32, n))
159-
D = Diagonal(d)
160-
B = AT(rand(Float32, n, n))
161-
X = AT(zeros(Float32, n, n))
162-
Y = zeros(Float32, n, n)
163-
α = rand(Float32)
164-
β = rand(Float32)
165-
mul!(X, D, B)
166-
mul!(Y, Diagonal(collect(d)), collect(B))
167-
@test collect(X) Y
168-
mul!(X, D, B, α, β)
169-
mul!(Y, Diagonal(collect(d)), collect(B), α, β)
170-
@test collect(X) Y
171-
mul!(X, B, D)
172-
mul!(Y, collect(B), Diagonal(collect(d)))
173-
@test collect(X) Y
174-
mul!(X, B, D, α, β)
175-
mul!(Y, collect(B), Diagonal(collect(d)), α, β)
176-
@test collect(X) Y
157+
for elty in (Float32, ComplexF32)
158+
n = 128
159+
d = AT(rand(elty, n))
160+
D = Diagonal(d)
161+
B = AT(rand(elty, n, n))
162+
X = AT(zeros(elty, n, n))
163+
Y = zeros(elty, n, n)
164+
α = rand(elty)
165+
β = rand(elty)
166+
mul!(X, D, B)
167+
mul!(Y, Diagonal(collect(d)), collect(B))
168+
@test collect(X) Y
169+
mul!(X, D, B, α, β)
170+
mul!(Y, Diagonal(collect(d)), collect(B), α, β)
171+
@test collect(X) Y
172+
mul!(X, B, D)
173+
mul!(Y, collect(B), Diagonal(collect(d)))
174+
@test collect(X) Y
175+
mul!(X, B, D, α, β)
176+
mul!(Y, collect(B), Diagonal(collect(d)), α, β)
177+
@test collect(X) Y
178+
end
177179
end
178180

179181
@testset "ldiv! + Diagonal" begin

0 commit comments

Comments
 (0)