@@ -28,6 +28,7 @@ const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
28
28
const AnyGPUVector{T} = AnyGPUArray{T, 1 }
29
29
const AnyGPUMatrix{T} = AnyGPUArray{T, 2 }
30
30
31
+
31
32
# # broadcasting
32
33
33
34
"""
@@ -38,6 +39,7 @@ this supertype.
38
39
"""
39
40
abstract type AbstractGPUArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
40
41
42
+
41
43
# # scalar iteration
42
44
43
45
export allowscalar, @allowscalar , assertscalar
@@ -46,42 +48,47 @@ export allowscalar, @allowscalar, assertscalar
46
48
47
49
# if the user explicitly calls allowscalar, use that setting for all new tasks
48
50
# XXX : use context variables to inherit the parent task's setting, once available.
49
- const default_scalar_indexing = Ref {Union{Nothing,ScalarIndexing}} (nothing )
51
+ const requested_scalar_indexing = Ref {Union{Nothing,ScalarIndexing}} (nothing )
50
52
51
- """
52
- allowscalar() do
53
- # code that can use scalar indexing
53
+ const _repl_frontend_task = Ref {Union{Nothing,Missing,Task}} ()
54
+ function repl_frontend_task ()
55
+ if ! isassigned (_repl_frontend_task)
56
+ _repl_frontend_task[] = get_repl_frontend_task ()
57
+ end
58
+ _repl_frontend_task[]
59
+ end
60
+ function get_repl_frontend_task ()
61
+ @static if VERSION >= v " 1.10.0-DEV.444" || v " 1.9-beta4" <= VERSION < v " 1.10-"
62
+ if isdefined (Base, :active_repl )
63
+ Base. active_repl. frontend_task
64
+ else
65
+ missing
66
+ end
67
+ else
68
+ nothing
54
69
end
55
-
56
- Denote which operations can use scalar indexing.
57
-
58
- See also: [`@allowscalar`](@ref).
59
- """
60
- function allowscalar (f:: Base.Callable )
61
- task_local_storage (f, :ScalarIndexing , ScalarAllowed)
62
70
end
63
71
64
- """
65
- allowscalar(::Bool)
66
-
67
- Calling this with `false` replaces the default warning about scalar indexing
68
- (show once per session) with an error.
69
-
70
- Instead of calling this with `true`, the preferred style is to allow this locally.
71
- This can be done with the `allowscalar(::Function)` method (with a `do` block)
72
- or with the [`@allowscalar`](@ref) macro.
73
-
74
- Writes to `task_local_storage` for `:ScalarIndexing`. The default is `:ScalarWarn`,
75
- and this function sets `:ScalarAllowed` or `:ScalarDisallowed`.
76
- """
77
- function allowscalar (allow:: Bool = true )
78
- if allow
79
- Base. depwarn (" allowscalar([true]) is deprecated, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations." , :allowscalar )
72
+ @noinline function default_scalar_indexing ()
73
+ if isinteractive ()
74
+ # try to detect the REPL
75
+ repl_task = repl_frontend_task ()
76
+ if repl_task isa Task
77
+ if repl_task === current_task ()
78
+ # we always allow scalar iteration on the REPL's frontend task,
79
+ # where we often trigger scalar indexing by displaying GPU objects.
80
+ ScalarAllowed
81
+ else
82
+ ScalarDisallowed
83
+ end
84
+ else
85
+ # we couldn't detect a REPL in this interactive session, so default to a warning
86
+ ScalarWarn
87
+ end
88
+ else
89
+ # non-interactively, we always disallow scalar iteration
90
+ ScalarDisallowed
80
91
end
81
- setting = allow ? ScalarAllowed : ScalarDisallowed
82
- task_local_storage (:ScalarIndexing , setting)
83
- default_scalar_indexing[] = setting
84
- return
85
92
end
86
93
87
94
"""
90
97
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
91
98
error will be thrown ([`allowscalar`](@ref)).
92
99
"""
93
- function assertscalar (op = " operation" )
94
- val = get! (task_local_storage (), :ScalarIndexing ) do
95
- something (default_scalar_indexing[], isinteractive () ? ScalarWarn : ScalarDisallowed)
100
+ function assertscalar (op:: String )
101
+ behavior = get (task_local_storage (), :ScalarIndexing , nothing )
102
+ if behavior === nothing
103
+ behavior = requested_scalar_indexing[]
104
+ if behavior === nothing
105
+ behavior = default_scalar_indexing ()
106
+ end
107
+ task_local_storage (:ScalarIndexing , behavior)
96
108
end
97
- desc = """ Invocation of $op resulted in scalar indexing of a GPU array.
109
+
110
+ behavior = behavior:: ScalarIndexing
111
+ if behavior === ScalarAllowed
112
+ # fast path
113
+ return
114
+ end
115
+
116
+ _assertscalar (op, behavior)
117
+ end
118
+
119
+ @noinline function _assertscalar (op, behavior)
120
+ desc = """ Invocation of '$op ' resulted in scalar indexing of a GPU array.
98
121
This is typically caused by calling an iterating implementation of a method.
99
122
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
100
- and therefore are only permitted from the REPL for prototyping purposes .
101
- If you did intend to index this array, annotate the caller with @allowscalar. """
102
- if val == ScalarDisallowed
103
- error ( """ Scalar indexing is disallowed.
104
- $desc """ )
105
- elseif val == ScalarWarn
106
- @warn ( """ Performing scalar indexing on task $( current_task ()) .
107
- $desc """ )
123
+ and therefore should be avoided .
124
+
125
+ If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
126
+ to enable scalar iteration globally or for the operations in question. """
127
+ if behavior == ScalarDisallowed
128
+ errorscalar (op)
129
+ elseif behavior == ScalarWarn
130
+ warnscalar (op )
108
131
task_local_storage (:ScalarIndexing , ScalarWarned)
109
132
end
133
+
110
134
return
111
135
end
112
136
137
+ function scalardesc (op)
138
+ desc = """ Invocation of $op resulted in scalar indexing of a GPU array.
139
+ This is typically caused by calling an iterating implementation of a method.
140
+ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
141
+ and therefore should be avoided.
142
+
143
+ If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
144
+ to enable scalar iteration globally or for the operations in question."""
145
+ end
146
+
147
+ @noinline function warnscalar (op)
148
+ desc = scalardesc (op)
149
+ @warn (""" Performing scalar indexing on task $(current_task ()) .
150
+ $desc """ )
151
+ end
152
+
153
+ @noinline function errorscalar (op)
154
+ desc = scalardesc (op)
155
+ error (""" Scalar indexing is disallowed.
156
+ $desc """ )
157
+ end
158
+
113
159
# Like a try-finally block, except without introducing the try scope
114
160
# NOTE: This is deprecated and should not be used from user logic. A proper solution to
115
161
# this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217
@@ -120,6 +166,34 @@ macro __tryfinally(ex, fin)
120
166
)
121
167
end
122
168
169
+ """
170
+ allowscalar([true])
171
+ allowscalar([true]) do
172
+ ...
173
+ end
174
+
175
+ Use this function to allow or disallow scalar indexing, either globall or for the
176
+ duration of the do block.
177
+
178
+ See also: [`@allowscalar`](@ref).
179
+ """
180
+ allowscalar
181
+
182
+ function allowscalar (f:: Base.Callable )
183
+ task_local_storage (f, :ScalarIndexing , ScalarAllowed)
184
+ end
185
+
186
+ function allowscalar (allow:: Bool = true )
187
+ if allow
188
+ @warn """ It's not recommended to use allowscalar([true]) to allow scalar indexing.
189
+ Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog= 1
190
+ end
191
+ setting = allow ? ScalarAllowed : ScalarDisallowed
192
+ task_local_storage (:ScalarIndexing , setting)
193
+ requested_scalar_indexing[] = setting
194
+ return
195
+ end
196
+
123
197
"""
124
198
@allowscalar() begin
125
199
# code that can use scalar indexing
@@ -139,6 +213,9 @@ macro allowscalar(ex)
139
213
end
140
214
end
141
215
216
+
217
+ # # other
218
+
142
219
"""
143
220
backend(T::Type)
144
221
backend(x)
0 commit comments