Skip to content

Commit 53f1eb8

Browse files
stevengjdkarrasch
andauthored
bugfix for dot of Hermitian{noncommutative} (#52333)
Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
1 parent 150c1ad commit 53f1eb8

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ function triu(A::Symmetric, k::Integer=0)
453453
end
454454
end
455455

456-
for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:Hermitian, :adjoint, :real)]
456+
for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Union{Real,Complex}}), :adjoint, :real)]
457457
@eval begin
458458
function dot(A::$T, B::$T)
459459
n = size(A, 2)

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ module TestSymmetric
44

55
using Test, LinearAlgebra, Random
66

7+
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
8+
9+
isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl"))
10+
using .Main.Quaternions
11+
712
Random.seed!(1010)
813

914
@testset "Pauli σ-matrices: " for σ in map(Hermitian,
@@ -462,6 +467,17 @@ end
462467
end
463468
end
464469

470+
# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
471+
# or any number type where conj(a)*conj(b) ≠ conj(a*b):
472+
@testset "dot Hermitian quaternion #52318" begin
473+
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
474+
@test A == Hermitian(A) && B == Hermitian(B)
475+
@test dot(A, B) dot(Hermitian(A), Hermitian(B))
476+
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
477+
@test A == Symmetric(A) && B == Symmetric(B)
478+
@test dot(A, B) dot(Symmetric(A), Symmetric(B))
479+
end
480+
465481
#Issue #7647: test xsyevr, xheevr, xstevr drivers.
466482
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
467483
(Symmetric(diagm(0 => 1.0:3.0)),

test/testhelpers/Quaternions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3
2020
Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3))
2121
Base.abs(q::Quaternion) = sqrt(abs2(q))
2222
Base.real(::Type{Quaternion{T}}) where {T} = T
23+
Base.real(q::Quaternion) = q.s
2324
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
2425
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)
2526
Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T))
@@ -33,7 +34,9 @@ Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*
3334
q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1,
3435
q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s)
3536
Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r)
36-
Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity
37+
Base.:(*)(q::Quaternion, r::Bool) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) # remove method ambiguity
38+
Base.:(*)(r::Real, q::Quaternion) = q * r
39+
Base.:(*)(r::Bool, q::Quaternion) = q * r # remove method ambiguity
3740
Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w))
3841
Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q))
3942

0 commit comments

Comments
 (0)