Skip to content

Commit f8d3762

Browse files
committed
Optimize assertscalar.
1 parent c149fa2 commit f8d3762

File tree

1 file changed

+85
-21
lines changed

1 file changed

+85
-21
lines changed

lib/GPUArraysCore/src/GPUArraysCore.jl

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,48 @@ export allowscalar, @allowscalar, assertscalar
4848

4949
# if the user explicitly calls allowscalar, use that setting for all new tasks
5050
# XXX: use context variables to inherit the parent task's setting, once available.
51-
const default_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing)
51+
const requested_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing)
52+
53+
const _repl_frontend_task = Ref{Union{Nothing,Missing,Task}}()
54+
function repl_frontend_task()
55+
if !isassigned(_repl_frontend_task)
56+
_repl_frontend_task[] = get_repl_frontend_task()
57+
end
58+
_repl_frontend_task[]
59+
end
60+
function get_repl_frontend_task()
61+
@static if VERSION >= v"1.10.0-DEV.444" || v"1.9-beta4" <= VERSION < v"1.10-"
62+
if isdefined(Base, :active_repl)
63+
Base.active_repl.frontend_task
64+
else
65+
missing
66+
end
67+
else
68+
nothing
69+
end
70+
end
71+
72+
@noinline function default_scalar_indexing()
73+
if isinteractive()
74+
# try to detect the REPL
75+
repl_task = repl_frontend_task()
76+
if repl_task isa Task
77+
if repl_task === current_task()
78+
# we always allow scalar iteration on the REPL's frontend task,
79+
# where we often trigger scalar indexing by displaying GPU objects.
80+
ScalarAllowed
81+
else
82+
ScalarDisallowed
83+
end
84+
else
85+
# we couldn't detect a REPL in this interactive session, so default to a warning
86+
ScalarWarn
87+
end
88+
else
89+
# non-interactively, we always disallow scalar iteration
90+
ScalarDisallowed
91+
end
92+
end
5293

5394
"""
5495
assertscalar(op::String)
@@ -57,41 +98,64 @@ Assert that a certain operation `op` performs scalar indexing. If this is not al
5798
error will be thrown ([`allowscalar`](@ref)).
5899
"""
59100
function assertscalar(op = "operation")
60-
# try to detect the REPL
61-
@static if VERSION >= v"1.10.0-DEV.444" || v"1.9-beta4" <= VERSION < v"1.10-"
62-
if isdefined(Base, :active_repl) && current_task() == Base.active_repl.frontend_task
63-
# we always allow scalar iteration on the REPL's frontend task,
64-
# where we often trigger scalar indexing by displaying GPU objects.
65-
return false
101+
behavior = get(task_local_storage(), :ScalarIndexing, nothing)
102+
if behavior === nothing
103+
behavior = requested_scalar_indexing[]
104+
if behavior === nothing
105+
behavior = default_scalar_indexing()
66106
end
67-
default_behavior = ScalarDisallowed
68-
else
69-
# we can't detect the REPL, but it will only be used in interactive sessions,
70-
# so default to allowing scalar indexing there (but warn).
71-
default_behavior = isinteractive() ? ScalarWarn : ScalarDisallowed
107+
task_local_storage(:ScalarIndexing, behavior)
72108
end
73109

74-
val = get!(task_local_storage(), :ScalarIndexing) do
75-
something(default_scalar_indexing[], default_behavior)
110+
behavior = behavior::ScalarIndexing
111+
if behavior === ScalarAllowed
112+
# fast path
113+
return
76114
end
115+
116+
_assertscalar(op, behavior)
117+
end
118+
119+
@noinline function _assertscalar(op, behavior)
77120
desc = """Invocation of $op resulted in scalar indexing of a GPU array.
78121
This is typically caused by calling an iterating implementation of a method.
79122
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
80123
and therefore should be avoided.
81124
82125
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
83126
to enable scalar iteration globally or for the operations in question."""
84-
if val == ScalarDisallowed
85-
error("""Scalar indexing is disallowed.
86-
$desc""")
87-
elseif val == ScalarWarn
88-
@warn("""Performing scalar indexing on task $(current_task()).
89-
$desc""")
127+
if behavior == ScalarDisallowed
128+
errorscalar(op)
129+
elseif behavior == ScalarWarn
130+
warnscalar(op)
90131
task_local_storage(:ScalarIndexing, ScalarWarned)
91132
end
133+
92134
return
93135
end
94136

137+
function scalardesc(op)
138+
desc = """Invocation of $op resulted in scalar indexing of a GPU array.
139+
This is typically caused by calling an iterating implementation of a method.
140+
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
141+
and therefore should be avoided.
142+
143+
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
144+
to enable scalar iteration globally or for the operations in question."""
145+
end
146+
147+
@noinline function warnscalar(op)
148+
desc = scalardesc(op)
149+
@warn("""Performing scalar indexing on task $(current_task()).
150+
$desc""")
151+
end
152+
153+
@noinline function errorscalar(op)
154+
desc = scalardesc(op)
155+
error("""Scalar indexing is disallowed.
156+
$desc""")
157+
end
158+
95159
# Like a try-finally block, except without introducing the try scope
96160
# NOTE: This is deprecated and should not be used from user logic. A proper solution to
97161
# this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217
@@ -126,7 +190,7 @@ function allowscalar(allow::Bool=true)
126190
end
127191
setting = allow ? ScalarAllowed : ScalarDisallowed
128192
task_local_storage(:ScalarIndexing, setting)
129-
default_scalar_indexing[] = setting
193+
requested_scalar_indexing[] = setting
130194
return
131195
end
132196

0 commit comments

Comments
 (0)