Skip to content

Commit dc7f5a1

Browse files
sethaxenhyrodium
andauthored
Make abs, abs_imag, inv, and / resistant to under/overflow (#122)
* Use hypot in abs and abs_imag * Use approximate check * Add tests against abs/exp under/overflow * Increment patch number * Ensure tests pass on 1.0 * Backport _hypot code for older Julia versions * Mark test as non broken * Implement abs without hypot * Implement abs_imag efficiently * Handle over/underflow for inv and div * Make inv implementation more like Complex's * Test abs_imag for all Julia versions * Add tests for inv under/overflow * Test more abs/abs_imag cases * Simplify tests with isequal * Use isequal * Add isequal * Add under/overflow tests for div * Add missing abs_imag tests * Apply suggestions from code review Co-authored-by: Yuto Horikawa <hyrodium@gmail.com> * Update test/Quaternion.jl Co-authored-by: Yuto Horikawa <hyrodium@gmail.com> * Update test/Quaternion.jl Co-authored-by: Yuto Horikawa <hyrodium@gmail.com> * Increment patch number Co-authored-by: Yuto Horikawa <hyrodium@gmail.com>
1 parent 7e0efc6 commit dc7f5a1

File tree

3 files changed

+128
-6
lines changed

3 files changed

+128
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Quaternions"
22
uuid = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
3-
version = "0.7.3"
3+
version = "0.7.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/Quaternion.jl

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,42 @@ Quaternion{Int64}(1, -2, -3, -4)
136136
```
137137
"""
138138
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
139-
Base.abs(q::Quaternion) = sqrt(abs2(q))
139+
function Base.abs(q::Quaternion)
140+
a = max(abs(q.s), abs(q.v1), abs(q.v2), abs(q.v3))
141+
if isnan(a) && isinf(q)
142+
return typeof(a)(Inf)
143+
elseif iszero(a) || isinf(a)
144+
return a
145+
else
146+
return sqrt(abs2(q / a)) * a
147+
end
148+
end
140149
Base.float(q::Quaternion{T}) where T = convert(Quaternion{float(T)}, q)
141-
abs_imag(q::Quaternion) = sqrt(q.v2 * q.v2 + (q.v1 * q.v1 + q.v3 * q.v3)) # ordered to match abs2
150+
function abs_imag(q::Quaternion)
151+
a = max(abs(q.v1), abs(q.v2), abs(q.v3))
152+
if isnan(a) && (isinf(q.v1) | isinf(q.v2) | isinf(q.v3))
153+
return oftype(a, Inf)
154+
elseif iszero(a) || isinf(a)
155+
return a
156+
else
157+
return sqrt((q.v1 / a)^2 + (q.v2 / a)^2 + (q.v3 / a)^2) * a
158+
end
159+
end
142160
Base.abs2(q::Quaternion) = RealDot.realdot(q,q)
143-
Base.inv(q::Quaternion) = conj(q) / abs2(q)
161+
function Base.inv(q::Quaternion)
162+
if isinf(q)
163+
return quat(
164+
copysign(zero(q.s), q.s),
165+
flipsign(-zero(q.v1), q.v1),
166+
flipsign(-zero(q.v2), q.v2),
167+
flipsign(-zero(q.v3), q.v3),
168+
)
169+
end
170+
a = max(abs(q.s), abs(q.v1), abs(q.v2), abs(q.v3))
171+
p = q / a
172+
iq = conj(p) / (a * abs2(p))
173+
return iq
174+
end
144175

145176
Base.isreal(q::Quaternion) = iszero(q.v1) & iszero(q.v2) & iszero(q.v3)
146177
Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3)
@@ -182,9 +213,28 @@ function Base.:*(q::Quaternion, w::Quaternion)
182213
return Quaternion(s, v1, v2, v3)
183214
end
184215

185-
Base.:/(q::Quaternion, w::Quaternion) = q * inv(w)
216+
function Base.:/(q::Quaternion{T}, w::Quaternion{T}) where T
217+
# handle over/underflow while matching the behavior of /(a::Complex, b::Complex)
218+
a = max(abs(w.s), abs(w.v1), abs(w.v2), abs(w.v3))
219+
if isinf(w)
220+
if isfinite(q)
221+
return quat(
222+
zero(T)*sign(q.s)*sign(w.s),
223+
-zero(T)*sign(q.v1)*sign(w.v1),
224+
-zero(T)*sign(q.v2)*sign(w.v2),
225+
-zero(T)*sign(q.v3)*sign(w.v3),
226+
)
227+
end
228+
return quat(T(NaN), T(NaN), T(NaN), T(NaN))
229+
end
230+
p = w / a
231+
return (q * conj(p)) / RealDot.realdot(w, p)
232+
end
186233

187234
Base.:(==)(q::Quaternion, w::Quaternion) = (q.s == w.s) & (q.v1 == w.v1) & (q.v2 == w.v2) & (q.v3 == w.v3)
235+
function Base.isequal(q::Quaternion, w::Quaternion)
236+
isequal(q.s, w.s) & isequal(q.v1, w.v1) & isequal(q.v2, w.v2) & isequal(q.v3, w.v3)
237+
end
188238

189239
"""
190240
extend_analytic(f, q::Quaternion)

test/Quaternion.jl

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ end
5959
@test Quaternion(1, 2, 3, 4) != Quaternion(1, 2, 3, 5)
6060
end
6161

62+
@testset "isequal" begin
63+
@test isequal(Quaternion(1, 2, 3, 4), Quaternion(1.0, 2.0, 3.0, 4.0))
64+
@test !isequal(Quaternion(1, 2, 3, 4), Quaternion(5, 2, 3, 4))
65+
@test isequal(Quaternion(NaN, -0.0, Inf, -Inf), Quaternion(NaN, -0.0, Inf, -Inf))
66+
@test !isequal(Quaternion(NaN, 0.0, Inf, -Inf), Quaternion(NaN, -0.0, Inf, -Inf))
67+
end
68+
6269
@testset "convert" begin
6370
@test convert(Quaternion{Float64}, 1) === Quaternion(1.0)
6471
@test convert(Quaternion{Float64}, Quaternion(1, 2, 3, 4)) ===
@@ -153,7 +160,31 @@ end
153160
@test conj(conj(q)) === q
154161
@test conj(conj(qnorm)) === qnorm
155162
@test float(Quaternion(1, 2, 3, 4)) === float(Quaternion(1.0, 2.0, 3.0, 4.0))
156-
@test Quaternions.abs_imag(q) == abs(Quaternion(0, q.v1, q.v2, q.v3))
163+
@test Quaternions.abs_imag(q) abs(Quaternion(0, q.v1, q.v2, q.v3))
164+
end
165+
166+
@testset "abs/abs_imag don't over/underflow" begin
167+
for x in [1e-300, 1e300, -1e-300, -1e300]
168+
@test abs(quat(x, 0, 0, 0)) == abs(x)
169+
@test abs(quat(0, x, 0, 0)) == abs(x)
170+
@test abs(quat(0, 0, x, 0)) == abs(x)
171+
@test abs(quat(0, 0, 0, x)) == abs(x)
172+
@test Quaternions.abs_imag(quat(0, x, 0, 0)) == abs(x)
173+
@test Quaternions.abs_imag(quat(0, 0, x, 0)) == abs(x)
174+
@test Quaternions.abs_imag(quat(0, 0, 0, x)) == abs(x)
175+
end
176+
@test isnan(abs(quat(NaN, NaN, NaN, NaN)))
177+
@test abs(quat(NaN, Inf, NaN, NaN)) == Inf
178+
@test abs(quat(-Inf, NaN, NaN, NaN)) == Inf
179+
@test abs(quat(0.0)) == 0.0
180+
@test abs(quat(Inf)) == Inf
181+
@test abs(quat(1, -Inf, 2, 3)) == Inf
182+
@test isnan(Quaternions.abs_imag(quat(0, NaN, NaN, NaN)))
183+
@test Quaternions.abs_imag(quat(0, Inf, NaN, NaN)) == Inf
184+
@test Quaternions.abs_imag(quat(0, NaN, -Inf, NaN)) == Inf
185+
@test Quaternions.abs_imag(quat(0.0)) == 0.0
186+
@test Quaternions.abs_imag(quat(0.0, 0.0, Inf, 0.0)) == Inf
187+
@test Quaternions.abs_imag(quat(0, 1, -Inf, 2)) == Inf
157188
end
158189

159190
@testset "algebraic properties" begin
@@ -171,6 +202,21 @@ end
171202
end
172203
end
173204

205+
@testset "inv does not under/overflow" begin
206+
x = 1e-300
207+
y = inv(x)
208+
@test isequal(inv(quat(x, 0.0, 0.0, 0.0)), quat(y, -0.0, -0.0, -0.0))
209+
@test isequal(inv(quat(0.0, x, 0.0, 0.0)), quat(0.0, -y, -0.0, -0.0))
210+
@test isequal(inv(quat(0.0, 0.0, x, 0.0)), quat(0.0, -0.0, -y, -0.0))
211+
@test isequal(inv(quat(0.0, 0.0, 0.0, x)), quat(0.0, -0.0, -0.0, -y))
212+
@test isequal(inv(quat(y, 0.0, 0.0, 0.0)), quat(x, -0.0, -0.0, -0.0))
213+
@test isequal(inv(quat(0.0, y, 0.0, 0.0)), quat(0.0, -x, -0.0, -0.0))
214+
@test isequal(inv(quat(0.0, 0.0, y, 0.0)), quat(0.0, -0.0, -x, -0.0))
215+
@test isequal(inv(quat(0.0, 0.0, 0.0, y)), quat(0.0, -0.0, -0.0, -x))
216+
@test isequal(inv(quat(-Inf, 1, -2, 3)), quat(-0.0, -0.0, 0.0, -0.0))
217+
@test isequal(inv(quat(1, -2, Inf, 3)), quat(0.0, 0.0, -0.0, -0.0))
218+
end
219+
174220
@testset "isreal" begin
175221
@test isreal(Quaternion(1, 0, 0, 0))
176222
@test !isreal(Quaternion(2, 1, 0, 0))
@@ -275,6 +321,32 @@ end
275321
@test q2 \ q inv(q2) * q
276322
@test q / x x \ q inv(x) * q
277323
end
324+
@testset "no overflow/underflow" begin
325+
@testset for x in [1e-300, 1e300, -1e-300, -1e300]
326+
@test quat(x) / quat(x) == quat(1)
327+
@test quat(x) / quat(0, x, 0, 0) == quat(0, -1, 0, 0)
328+
@test quat(x) / quat(0, 0, x, 0) == quat(0, 0, -1, 0)
329+
@test quat(x) / quat(0, 0, 0, x) == quat(0, 0, 0, -1)
330+
@test quat(0, x, 0, 0) / quat(x, 0, 0, 0) == quat(0, 1, 0, 0)
331+
@test quat(0, x, 0, 0) / quat(0, x, 0, 0) == quat(1, 0, 0, 0)
332+
@test quat(0, x, 0, 0) / quat(0, 0, x, 0) == quat(0, 0, 0, -1)
333+
@test quat(0, x, 0, 0) / quat(0, 0, 0, x) == quat(0, 0, 1, 0)
334+
end
335+
@testset for T in [Float32, Float64]
336+
o = one(T)
337+
z = zero(T)
338+
inf = T(Inf)
339+
nan = T(NaN)
340+
@testset for s in [1, -1], t in [1, -1]
341+
@test isequal(quat(o) / quat(s*inf), quat(s*z, -z, -z, -z))
342+
@test isequal(quat(o) / quat(s*inf, t*o, z, t*z), quat(s*z, -t*z, -z, -t*z))
343+
@test isequal(quat(o) / quat(s*inf, t*nan, t*z, z), quat(s*z, nan, -t*z, -z))
344+
@test isequal(quat(o) / quat(s*inf, t*inf, t*z, z), quat(s*z, -t*z, -t*z, -z))
345+
end
346+
@test isequal(quat(inf) / quat(inf, 1, 2, 3), quat(nan, nan, nan, nan))
347+
@test isequal(quat(inf) / quat(inf, 1, 2, -inf), quat(nan, nan, nan, nan))
348+
end
349+
end
278350
end
279351

280352
@testset "^" begin

0 commit comments

Comments
 (0)