Skip to content

Commit 474050e

Browse files
nsajkovchuravy
andauthored
prevent get_backend from overflowing the stack (#602)
Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent bc646a4 commit 474050e

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/KernelAbstractions.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,13 @@ Get a [`Backend`](@ref) instance suitable for array `A`.
511511
function get_backend end
512512

513513
# Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.:
514-
get_backend(A::AbstractArray) = get_backend(parent(A))
514+
function get_backend(A::AbstractArray)
515+
P = parent(A)
516+
if P isa typeof(A)
517+
throw(ArgumentError("Implement `KernelAbstractions.get_backend(::$(typeof(A)))`"))
518+
end
519+
return get_backend(P)
520+
end
515521

516522
# Define:
517523
# adapt_storage(::Backend, a::Array) = adapt(BackendArray, a)

test/test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Adapt
77

88
identity(x) = x
99

10+
struct UnknownAbstractVector <: AbstractVector{Float32} # issue #588
11+
end
12+
1013
function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; skip_tests = Set{String}())
1114
@conditional_testset "partition" skip_tests begin
1215
backend = Backend()
@@ -80,6 +83,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
8083
@test @inferred(KernelAbstractions.get_backend(view(A, 2:4, 1:3))) isa backendT
8184
@test @inferred(KernelAbstractions.get_backend(Diagonal(x))) isa backendT
8285
@test @inferred(KernelAbstractions.get_backend(Tridiagonal(A))) isa backendT
86+
@test_throws ArgumentError KernelAbstractions.get_backend(UnknownAbstractVector()) # issue #588
8387
end
8488

8589
@conditional_testset "sparse" skip_tests begin

0 commit comments

Comments
 (0)