Skip to content

Commit c628dda

Browse files
authored
Merge pull request #450 from JuliaGPU/tb/repl_allowscalar
Allow scalar iteration in the REPL on 1.9+.
2 parents c015cca + 7e8e6df commit c628dda

File tree

1 file changed

+120
-43
lines changed

1 file changed

+120
-43
lines changed

lib/GPUArraysCore/src/GPUArraysCore.jl

Lines changed: 120 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
2828
const AnyGPUVector{T} = AnyGPUArray{T, 1}
2929
const AnyGPUMatrix{T} = AnyGPUArray{T, 2}
3030

31+
3132
## broadcasting
3233

3334
"""
@@ -38,6 +39,7 @@ this supertype.
3839
"""
3940
abstract type AbstractGPUArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
4041

42+
4143
## scalar iteration
4244

4345
export allowscalar, @allowscalar, assertscalar
@@ -46,42 +48,47 @@ export allowscalar, @allowscalar, assertscalar
4648

4749
# if the user explicitly calls allowscalar, use that setting for all new tasks
4850
# XXX: use context variables to inherit the parent task's setting, once available.
49-
const default_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing)
51+
const requested_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing)
5052

51-
"""
52-
allowscalar() do
53-
# code that can use scalar indexing
53+
const _repl_frontend_task = Ref{Union{Nothing,Missing,Task}}()
54+
function repl_frontend_task()
55+
if !isassigned(_repl_frontend_task)
56+
_repl_frontend_task[] = get_repl_frontend_task()
57+
end
58+
_repl_frontend_task[]
59+
end
60+
function get_repl_frontend_task()
61+
@static if VERSION >= v"1.10.0-DEV.444" || v"1.9-beta4" <= VERSION < v"1.10-"
62+
if isdefined(Base, :active_repl)
63+
Base.active_repl.frontend_task
64+
else
65+
missing
66+
end
67+
else
68+
nothing
5469
end
55-
56-
Denote which operations can use scalar indexing.
57-
58-
See also: [`@allowscalar`](@ref).
59-
"""
60-
function allowscalar(f::Base.Callable)
61-
task_local_storage(f, :ScalarIndexing, ScalarAllowed)
6270
end
6371

64-
"""
65-
allowscalar(::Bool)
66-
67-
Calling this with `false` replaces the default warning about scalar indexing
68-
(show once per session) with an error.
69-
70-
Instead of calling this with `true`, the preferred style is to allow this locally.
71-
This can be done with the `allowscalar(::Function)` method (with a `do` block)
72-
or with the [`@allowscalar`](@ref) macro.
73-
74-
Writes to `task_local_storage` for `:ScalarIndexing`. The default is `:ScalarWarn`,
75-
and this function sets `:ScalarAllowed` or `:ScalarDisallowed`.
76-
"""
77-
function allowscalar(allow::Bool=true)
78-
if allow
79-
Base.depwarn("allowscalar([true]) is deprecated, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.", :allowscalar)
72+
@noinline function default_scalar_indexing()
73+
if isinteractive()
74+
# try to detect the REPL
75+
repl_task = repl_frontend_task()
76+
if repl_task isa Task
77+
if repl_task === current_task()
78+
# we always allow scalar iteration on the REPL's frontend task,
79+
# where we often trigger scalar indexing by displaying GPU objects.
80+
ScalarAllowed
81+
else
82+
ScalarDisallowed
83+
end
84+
else
85+
# we couldn't detect a REPL in this interactive session, so default to a warning
86+
ScalarWarn
87+
end
88+
else
89+
# non-interactively, we always disallow scalar iteration
90+
ScalarDisallowed
8091
end
81-
setting = allow ? ScalarAllowed : ScalarDisallowed
82-
task_local_storage(:ScalarIndexing, setting)
83-
default_scalar_indexing[] = setting
84-
return
8592
end
8693

8794
"""
@@ -90,26 +97,65 @@ end
9097
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
9198
error will be thrown ([`allowscalar`](@ref)).
9299
"""
93-
function assertscalar(op = "operation")
94-
val = get!(task_local_storage(), :ScalarIndexing) do
95-
something(default_scalar_indexing[], isinteractive() ? ScalarWarn : ScalarDisallowed)
100+
function assertscalar(op::String)
101+
behavior = get(task_local_storage(), :ScalarIndexing, nothing)
102+
if behavior === nothing
103+
behavior = requested_scalar_indexing[]
104+
if behavior === nothing
105+
behavior = default_scalar_indexing()
106+
end
107+
task_local_storage(:ScalarIndexing, behavior)
96108
end
97-
desc = """Invocation of $op resulted in scalar indexing of a GPU array.
109+
110+
behavior = behavior::ScalarIndexing
111+
if behavior === ScalarAllowed
112+
# fast path
113+
return
114+
end
115+
116+
_assertscalar(op, behavior)
117+
end
118+
119+
@noinline function _assertscalar(op, behavior)
120+
desc = """Invocation of '$op' resulted in scalar indexing of a GPU array.
98121
This is typically caused by calling an iterating implementation of a method.
99122
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
100-
and therefore are only permitted from the REPL for prototyping purposes.
101-
If you did intend to index this array, annotate the caller with @allowscalar."""
102-
if val == ScalarDisallowed
103-
error("""Scalar indexing is disallowed.
104-
$desc""")
105-
elseif val == ScalarWarn
106-
@warn("""Performing scalar indexing on task $(current_task()).
107-
$desc""")
123+
and therefore should be avoided.
124+
125+
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
126+
to enable scalar iteration globally or for the operations in question."""
127+
if behavior == ScalarDisallowed
128+
errorscalar(op)
129+
elseif behavior == ScalarWarn
130+
warnscalar(op)
108131
task_local_storage(:ScalarIndexing, ScalarWarned)
109132
end
133+
110134
return
111135
end
112136

137+
function scalardesc(op)
138+
desc = """Invocation of $op resulted in scalar indexing of a GPU array.
139+
This is typically caused by calling an iterating implementation of a method.
140+
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
141+
and therefore should be avoided.
142+
143+
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
144+
to enable scalar iteration globally or for the operations in question."""
145+
end
146+
147+
@noinline function warnscalar(op)
148+
desc = scalardesc(op)
149+
@warn("""Performing scalar indexing on task $(current_task()).
150+
$desc""")
151+
end
152+
153+
@noinline function errorscalar(op)
154+
desc = scalardesc(op)
155+
error("""Scalar indexing is disallowed.
156+
$desc""")
157+
end
158+
113159
# Like a try-finally block, except without introducing the try scope
114160
# NOTE: This is deprecated and should not be used from user logic. A proper solution to
115161
# this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217
@@ -120,6 +166,34 @@ macro __tryfinally(ex, fin)
120166
)
121167
end
122168

169+
"""
170+
allowscalar([true])
171+
allowscalar([true]) do
172+
...
173+
end
174+
175+
Use this function to allow or disallow scalar indexing, either globall or for the
176+
duration of the do block.
177+
178+
See also: [`@allowscalar`](@ref).
179+
"""
180+
allowscalar
181+
182+
function allowscalar(f::Base.Callable)
183+
task_local_storage(f, :ScalarIndexing, ScalarAllowed)
184+
end
185+
186+
function allowscalar(allow::Bool=true)
187+
if allow
188+
@warn """It's not recommended to use allowscalar([true]) to allow scalar indexing.
189+
Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog=1
190+
end
191+
setting = allow ? ScalarAllowed : ScalarDisallowed
192+
task_local_storage(:ScalarIndexing, setting)
193+
requested_scalar_indexing[] = setting
194+
return
195+
end
196+
123197
"""
124198
@allowscalar() begin
125199
# code that can use scalar indexing
@@ -139,6 +213,9 @@ macro allowscalar(ex)
139213
end
140214
end
141215

216+
217+
## other
218+
142219
"""
143220
backend(T::Type)
144221
backend(x)

0 commit comments

Comments
 (0)