From 7409f4a52cf4aee970e8cdd4d05ad1ee2162c646 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 30 Jun 2025 10:44:34 +0200 Subject: [PATCH] add get_backend for StaticArrays (cherry picked from commit b8b53da42ca83448f77427e07c242718cf364b3c) --- Project.toml | 4 +++- ext/StaticArraysExt.jl | 9 +++++++++ src/KernelAbstractions.jl | 1 - test/test.jl | 13 +++++++++++++ 4 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 ext/StaticArraysExt.jl diff --git a/Project.toml b/Project.toml index 775be0e79..243140020 100644 --- a/Project.toml +++ b/Project.toml @@ -16,11 +16,11 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8" SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c" SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd" [weakdeps] +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -29,6 +29,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" EnzymeExt = "EnzymeCore" LinearAlgebraExt = "LinearAlgebra" SparseArraysExt = "SparseArrays" +StaticArraysExt = "StaticArrays" [compat] Adapt = "0.4, 1.0, 2.0, 3.0, 4" @@ -50,6 +51,7 @@ julia = "1.10" pocl_jll = "7" [extras] +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/ext/StaticArraysExt.jl b/ext/StaticArraysExt.jl new file mode 100644 index 000000000..8cdcdd308 --- /dev/null +++ b/ext/StaticArraysExt.jl @@ -0,0 +1,9 @@ +module StaticArraysExt + +import KernelAbstractions: get_backend, CPU +using StaticArrays: SizedArray, MArray + +get_backend(A::SizedArray) = get_backend(A.data) +get_backend(::MArray) = CPU() + +end diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 15757e3a2..43404d9b4 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -12,7 +12,6 @@ import PrecompileTools import Atomix: @atomic, @atomicswap, @atomicreplace using MacroTools -using StaticArrays using Adapt """ diff --git a/test/test.jl b/test/test.jl index 640b32c0d..29cedb43d 100644 --- a/test/test.jl +++ b/test/test.jl @@ -4,6 +4,7 @@ using InteractiveUtils using LinearAlgebra using SparseArrays using Adapt +using StaticArrays identity(x) = x @@ -95,6 +96,18 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk @test @inferred(KernelAbstractions.get_backend(sparse(A))) isa backendT end + @conditional_testset "StaticArrays" skip_tests begin + backend = Backend() + backendT = typeof(backend).name.wrapper # To look through CUDABackend{true, false} + @test backend isa backendT + + @test KernelAbstractions.get_backend(@MMatrix [1.0]) isa CPU + @test_throws ArgumentError KernelAbstractions.get_backend(@SMatrix [1.0]) + + A = allocate(backend, Float32, 5, 5) + @test @inferred(KernelAbstractions.get_backend(SizedMatrix{5, 5}(A))) isa backendT + end + @conditional_testset "adapt" skip_tests begin backend = Backend() x = allocate(backend, Float32, 5)