Skip to content

Commit 6a58662

Browse files
authored
Add extension package for StaticArrays (#2273)
1 parent 7a9a0a2 commit 6a58662

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

ext/StaticArraysExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# compatibility with StaticArrays
2+
3+
module StaticArraysExt
4+
5+
using ..CUDA
6+
using ..CUDA: @device_override, @print_and_throw
7+
8+
import StaticArrays
9+
10+
# same quirk as for some Base methods in src/device/quirks.jl
11+
@device_override @noinline StaticArrays.dimension_mismatch_fail(::Type{SA}, a::AbstractArray) where {SA<:StaticArrays.StaticArray} =
12+
@print_and_throw("DimensionMismatch while trying to convert to StaticArray: Expected and actual length of input array differ.")
13+
14+
end # extension module

src/CUDA.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ include("CUDAKernels.jl")
122122
import .CUDAKernels: CUDABackend
123123
export CUDABackend
124124

125+
# StaticArrays is still a direct dependency, so directly include the extension
126+
include("../ext/StaticArraysExt.jl")
127+
125128
include("precompile.jl")
126129

127130
end

test/libraries/staticarrays.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using LinearAlgebra: mul!
2+
using StaticArrays
3+
4+
@testset "StaticArrays" begin
5+
function batched_matvec(ms::CuArray, vs::CuArray)
6+
function matvec_kernel(out, ms, vs, ::Val{N}, ::Val{M}) where {N, M}
7+
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
8+
# Call constructors without @inbounds.
9+
# This asserts that the @device_override
10+
# for StaticArrays.dimension_mismatch_fail() works.
11+
m = SMatrix{N, M, Float32}(@view ms[:, :, i])
12+
v = SVector{M, Float32}(@view vs[:, i])
13+
out[:, i] .= m * v
14+
nothing
15+
end
16+
17+
out = similar(ms, (size(ms, 1), size(ms, 3)))
18+
@cuda threads=size(ms, 3) matvec_kernel(out, ms, vs, Val(size(ms, 1)), Val(size(ms, 2)))
19+
out
20+
end
21+
22+
function batched_matvec(ms, vs)
23+
out = similar(ms, (size(ms, 1), size(ms, 3)))
24+
foreach((o, m, v) -> mul!(o, m, v), eachcol(out), eachslice(ms; dims=3), eachcol(vs))
25+
out
26+
end
27+
28+
ms, vs = randn(Float32, 3, 2, 4), randn(Float32, 2, 4)
29+
@test batched_matvec(ms, vs) Array(batched_matvec(cu(ms), cu(vs)))
30+
end

0 commit comments

Comments
 (0)