@@ -117,28 +117,14 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
117
117
end
118
118
119
119
# element-wise kernel: each work-item writes exactly one linear index
120
- @kernel function map_kernel (dest, bc, nelem, common_length)
121
-
122
- j = 0
123
- J = @index (Global, Linear)
124
- for i in 1 : nelem
125
- j += 1
126
- if j <= common_length
127
-
128
- J_c = CartesianIndices (axes (bc))[(J- 1 )* nelem + j]
129
- @inbounds dest[J_c] = bc[J_c]
130
- end
131
- end
120
+ @kernel function map_kernel (dest, bc)
121
+ j = @index (Global, Linear)
122
+ @inbounds dest[j] = bc[j]
132
123
end
133
- elements = common_length
134
- elements_per_thread = typemax (Int)
124
+
135
125
kernel = map_kernel (get_backend (dest))
136
- heuristic = launch_heuristic (get_backend (dest), kernel, dest, bc, 1 ,
137
- common_length; elements, elements_per_thread)
138
- config = launch_configuration (get_backend (dest), heuristic;
139
- elements, elements_per_thread)
140
- kernel (dest, bc, config. elements_per_thread,
141
- common_length; ndrange = config. threads)
126
+ config = KernelAbstractions. launch_config (kernel, common_length, nothing )
127
+ kernel (dest, bc; ndrange = config[1 ], workgroupsize = config[2 ])
142
128
143
129
if eltype (dest) <: BrokenBroadcast
144
130
throw (ArgumentError (" Map operation resulting in $(eltype (eltype (dest))) is not GPU compatible" ))
0 commit comments