Skip to content

Commit d336cfb

Browse files
authored
specialize mul_fast and add_fast to allow FMA (#988)
* specialize mul_fast and add_fast to allow FMA * bump patch version number
1 parent 4c5e34e commit d336cfb

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.3.2"
3+
version = "1.3.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/linalg.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ end
5050
@inline Base.muladd(scalar::Number, a::StaticArray, b::StaticArray) = map((ai, bi) -> muladd(scalar, ai, bi), a, b)
5151
@inline Base.muladd(a::StaticArray, scalar::Number, b::StaticArray) = map((ai, bi) -> muladd(ai, scalar, bi), a, b)
5252

53+
54+
# @fastmath operators
55+
@inline Base.FastMath.mul_fast(a::Number, b::StaticArray) = map(c -> Base.FastMath.mul_fast(a, c), b)
56+
@inline Base.FastMath.mul_fast(a::StaticArray, b::Number) = map(c -> Base.FastMath.mul_fast(c, b), a)
57+
58+
@inline Base.FastMath.add_fast(a::StaticArray, b::StaticArray) = map(Base.FastMath.add_fast, a, b)
59+
@inline Base.FastMath.sub_fast(a::StaticArray, b::StaticArray) = map(Base.FastMath.sub_fast, a, b)
60+
61+
5362
#--------------------------------------------------
5463
# Matrix algebra
5564

test/linalg.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,31 @@ StaticArrays.similar_type(::Union{RotMat2,Type{RotMat2}}) = SMatrix{2,2,Float64,
6969
end
7070
end
7171

72+
@testset "@fastmath operators" begin
73+
for T in (Int, Float32, Float64)
74+
s0 = convert(T, 2)
75+
v1 = @SVector T[2, 4, 6, 8]
76+
v2 = @SVector T[4, 3, 2, 1]
77+
m1 = @SMatrix T[2 4; 6 8]
78+
m2 = @SMatrix T[4 3; 2 1]
79+
80+
# Use that these small integers can be represetnted exactly
81+
# as floating point numbers. In general, the comparison of
82+
# floats should use `≈` instead of `===`.
83+
# These should be turned into `vfmadd...` calls
84+
@test @fastmath(@inferred(s0 * v1 + v2)) === @SVector T[8, 11, 14, 17]
85+
@test @fastmath(@inferred(v1 * s0 + v2)) === @SVector T[8, 11, 14, 17]
86+
@test @fastmath(@inferred(s0 * m1 + m2)) === @SMatrix T[8 11; 14 17]
87+
@test @fastmath(@inferred(m1 * s0 + m2)) === @SMatrix T[8 11; 14 17]
88+
89+
# These should be turned into `vfmsub...` calls
90+
@test @fastmath(@inferred(s0 * v1 - v2)) === @SVector T[0, 5, 10, 15]
91+
@test @fastmath(@inferred(v1 * s0 - v2)) === @SVector T[0, 5, 10, 15]
92+
@test @fastmath(@inferred(s0 * m1 - m2)) === @SMatrix T[0 5; 10 15]
93+
@test @fastmath(@inferred(m1 * s0 - m2)) === @SMatrix T[0 5; 10 15]
94+
end
95+
end
96+
7297
@testset "Interaction with `UniformScaling`" begin
7398
@test @inferred(@SMatrix([0 1; 2 3]) + I) === @SMatrix [1 1; 2 4]
7499
@test @inferred(I + @SMatrix([0 1; 2 3])) === @SMatrix [1 1; 2 4]

0 commit comments

Comments
 (0)