Skip to content

Commit 28cf68c

Browse files
nsajkovchuravy
andcommitted
prevent get_backend from overflowing the stack (#602)
Co-authored-by: Valentin Churavy <v.churavy@gmail.com> (cherry picked from commit 474050e)
1 parent ccb0211 commit 28cf68c

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/KernelAbstractions.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,13 @@ Get a [`Backend`](@ref) instance suitable for array `A`.
516516
function get_backend end
517517

518518
# Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.:
519-
get_backend(A::AbstractArray) = get_backend(parent(A))
519+
function get_backend(A::AbstractArray)
520+
P = parent(A)
521+
if P isa typeof(A)
522+
throw(ArgumentError("Implement `KernelAbstractions.get_backend(::$(typeof(A)))`"))
523+
end
524+
return get_backend(P)
525+
end
520526

521527
get_backend(::Array) = CPU()
522528

test/test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Adapt
77

88
identity(x) = x
99

10+
struct UnknownAbstractVector <: AbstractVector{Float32} # issue #588
11+
end
12+
1013
function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; skip_tests = Set{String}())
1114
@conditional_testset "partition" skip_tests begin
1215
backend = Backend()
@@ -80,6 +83,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
8083
@test @inferred(KernelAbstractions.get_backend(view(A, 2:4, 1:3))) isa backendT
8184
@test @inferred(KernelAbstractions.get_backend(Diagonal(x))) isa backendT
8285
@test @inferred(KernelAbstractions.get_backend(Tridiagonal(A))) isa backendT
86+
@test_throws ArgumentError KernelAbstractions.get_backend(UnknownAbstractVector()) # issue #588
8387
end
8488

8589
@conditional_testset "sparse" skip_tests begin

0 commit comments

Comments
 (0)