@@ -48,7 +48,48 @@ export allowscalar, @allowscalar, assertscalar
48
48
49
49
# if the user explicitly calls allowscalar, use that setting for all new tasks
50
50
# XXX : use context variables to inherit the parent task's setting, once available.
51
- const default_scalar_indexing = Ref {Union{Nothing,ScalarIndexing}} (nothing )
51
+ const requested_scalar_indexing = Ref {Union{Nothing,ScalarIndexing}} (nothing )
52
+
53
+ const _repl_frontend_task = Ref {Union{Nothing,Missing,Task}} ()
54
+ function repl_frontend_task ()
55
+ if ! isassigned (_repl_frontend_task)
56
+ _repl_frontend_task[] = get_repl_frontend_task ()
57
+ end
58
+ _repl_frontend_task[]
59
+ end
60
+ function get_repl_frontend_task ()
61
+ @static if VERSION >= v " 1.10.0-DEV.444" || v " 1.9-beta4" <= VERSION < v " 1.10-"
62
+ if isdefined (Base, :active_repl )
63
+ Base. active_repl. frontend_task
64
+ else
65
+ missing
66
+ end
67
+ else
68
+ nothing
69
+ end
70
+ end
71
+
72
+ @noinline function default_scalar_indexing ()
73
+ if isinteractive ()
74
+ # try to detect the REPL
75
+ repl_task = repl_frontend_task ()
76
+ if repl_task isa Task
77
+ if repl_task === current_task ()
78
+ # we always allow scalar iteration on the REPL's frontend task,
79
+ # where we often trigger scalar indexing by displaying GPU objects.
80
+ ScalarAllowed
81
+ else
82
+ ScalarDisallowed
83
+ end
84
+ else
85
+ # we couldn't detect a REPL in this interactive session, so default to a warning
86
+ ScalarWarn
87
+ end
88
+ else
89
+ # non-interactively, we always disallow scalar iteration
90
+ ScalarDisallowed
91
+ end
92
+ end
52
93
53
94
"""
54
95
assertscalar(op::String)
@@ -57,41 +98,64 @@ Assert that a certain operation `op` performs scalar indexing. If this is not al
57
98
error will be thrown ([`allowscalar`](@ref)).
58
99
"""
59
100
function assertscalar (op = " operation" )
60
- # try to detect the REPL
61
- @static if VERSION >= v " 1.10.0-DEV.444" || v " 1.9-beta4" <= VERSION < v " 1.10-"
62
- if isdefined (Base, :active_repl ) && current_task () == Base. active_repl. frontend_task
63
- # we always allow scalar iteration on the REPL's frontend task,
64
- # where we often trigger scalar indexing by displaying GPU objects.
65
- return false
101
+ behavior = get (task_local_storage (), :ScalarIndexing , nothing )
102
+ if behavior === nothing
103
+ behavior = requested_scalar_indexing[]
104
+ if behavior === nothing
105
+ behavior = default_scalar_indexing ()
66
106
end
67
- default_behavior = ScalarDisallowed
68
- else
69
- # we can't detect the REPL, but it will only be used in interactive sessions,
70
- # so default to allowing scalar indexing there (but warn).
71
- default_behavior = isinteractive () ? ScalarWarn : ScalarDisallowed
107
+ task_local_storage (:ScalarIndexing , behavior)
72
108
end
73
109
74
- val = get! (task_local_storage (), :ScalarIndexing ) do
75
- something (default_scalar_indexing[], default_behavior)
110
+ behavior = behavior:: ScalarIndexing
111
+ if behavior === ScalarAllowed
112
+ # fast path
113
+ return
76
114
end
115
+
116
+ _assertscalar (op, behavior)
117
+ end
118
+
119
+ @noinline function _assertscalar (op, behavior)
77
120
desc = """ Invocation of $op resulted in scalar indexing of a GPU array.
78
121
This is typically caused by calling an iterating implementation of a method.
79
122
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
80
123
and therefore should be avoided.
81
124
82
125
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
83
126
to enable scalar iteration globally or for the operations in question."""
84
- if val == ScalarDisallowed
85
- error (""" Scalar indexing is disallowed.
86
- $desc """ )
87
- elseif val == ScalarWarn
88
- @warn (""" Performing scalar indexing on task $(current_task ()) .
89
- $desc """ )
127
+ if behavior == ScalarDisallowed
128
+ errorscalar (op)
129
+ elseif behavior == ScalarWarn
130
+ warnscalar (op)
90
131
task_local_storage (:ScalarIndexing , ScalarWarned)
91
132
end
133
+
92
134
return
93
135
end
94
136
137
+ function scalardesc (op)
138
+ desc = """ Invocation of $op resulted in scalar indexing of a GPU array.
139
+ This is typically caused by calling an iterating implementation of a method.
140
+ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
141
+ and therefore should be avoided.
142
+
143
+ If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
144
+ to enable scalar iteration globally or for the operations in question."""
145
+ end
146
+
147
+ @noinline function warnscalar (op)
148
+ desc = scalardesc (op)
149
+ @warn (""" Performing scalar indexing on task $(current_task ()) .
150
+ $desc """ )
151
+ end
152
+
153
+ @noinline function errorscalar (op)
154
+ desc = scalardesc (op)
155
+ error (""" Scalar indexing is disallowed.
156
+ $desc """ )
157
+ end
158
+
95
159
# Like a try-finally block, except without introducing the try scope
96
160
# NOTE: This is deprecated and should not be used from user logic. A proper solution to
97
161
# this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217
@@ -126,7 +190,7 @@ function allowscalar(allow::Bool=true)
126
190
end
127
191
setting = allow ? ScalarAllowed : ScalarDisallowed
128
192
task_local_storage (:ScalarIndexing , setting)
129
- default_scalar_indexing [] = setting
193
+ requested_scalar_indexing [] = setting
130
194
return
131
195
end
132
196
0 commit comments