1
1
module CUDAKernels
2
2
3
3
import CUDA
4
- import SpecialFunctions
5
4
import StaticArrays
6
5
import StaticArrays: MArray
7
- import Cassette
8
6
import Adapt
9
7
import KernelAbstractions
10
8
@@ -191,7 +189,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
191
189
ndrange, workgroupsize, iterspace, dynamic = launch_config (obj, ndrange, workgroupsize)
192
190
# this might not be the final context, since we may tune the workgroupsize
193
191
ctx = mkcontext (obj, ndrange, iterspace)
194
- kernel = CUDA. @cuda launch= false name = String ( nameof ( obj. f)) Cassette . overdub (CUDACTX, obj . f, ctx, args... )
192
+ kernel = CUDA. @cuda launch= false obj. f ( ctx, args... )
195
193
196
194
# figure out the optimal workgroupsize automatically
197
195
if KernelAbstractions. workgroupsize (obj) <: DynamicSize && workgroupsize === nothing
@@ -220,52 +218,49 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
220
218
221
219
# Launch kernel
222
220
event = CUDA. CuEvent (CUDA. EVENT_DISABLE_TIMING)
223
- kernel (CUDACTX, obj . f, ctx, args... ; threads= threads, blocks= nblocks, stream= stream)
221
+ kernel (ctx, args... ; threads= threads, blocks= nblocks, stream= stream)
224
222
225
223
CUDA. record (event, stream)
226
224
return CudaEvent (event)
227
225
end
228
226
229
- Cassette . @context CUDACtx
227
+ import CUDA : @device_override
230
228
231
229
import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearIndices
232
230
import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
233
231
import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds
234
232
235
- const CUDACTX = Cassette. disablehooks (CUDACtx (pass = CompilerPass))
236
- KernelAbstractions. cassette (:: Kernel{CUDADevice} ) = CUDACTX
237
-
238
233
function mkcontext (kernel:: Kernel{CUDADevice} , _ndrange, iterspace)
239
234
CompilerMetadata {KernelAbstractions.ndrange(kernel), DynamicCheck} (_ndrange, iterspace)
240
235
end
241
236
242
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Local_Linear), ctx)
237
+ @device_override @ inline function __index_Local_Linear ( ctx)
243
238
return CUDA. threadIdx (). x
244
239
end
245
240
246
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Group_Linear), ctx)
241
+ @device_override @ inline function __index_Group_Linear ( ctx)
247
242
return CUDA. blockIdx (). x
248
243
end
249
244
250
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Global_Linear), ctx)
245
+ @device_override @ inline function __index_Global_Linear ( ctx)
251
246
I = @inbounds expand (__iterspace (ctx), CUDA. blockIdx (). x, CUDA. threadIdx (). x)
252
247
# TODO : This is unfortunate, can we get the linear index cheaper
253
248
@inbounds LinearIndices (__ndrange (ctx))[I]
254
249
end
255
250
256
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Local_Cartesian), ctx)
251
+ @device_override @ inline function __index_Local_Cartesian ( ctx)
257
252
@inbounds workitems (__iterspace (ctx))[CUDA. threadIdx (). x]
258
253
end
259
254
260
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Group_Cartesian), ctx)
255
+ @device_override @ inline function __index_Group_Cartesian ( ctx)
261
256
@inbounds blocks (__iterspace (ctx))[CUDA. blockIdx (). x]
262
257
end
263
258
264
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Global_Cartesian), ctx)
259
+ @device_override @ inline function __index_Global_Cartesian ( ctx)
265
260
return @inbounds expand (__iterspace (ctx), CUDA. blockIdx (). x, CUDA. threadIdx (). x)
266
261
end
267
262
268
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __validindex), ctx)
263
+ @device_override @ inline function __validindex ( ctx)
269
264
if __dynamic_checkbounds (ctx)
270
265
I = @inbounds expand (__iterspace (ctx), CUDA. blockIdx (). x, CUDA. threadIdx (). x)
271
266
return I in __ndrange (ctx)
276
271
277
272
import KernelAbstractions: groupsize, __groupsize, __workitems_iterspace, add_float_contract, sub_float_contract, mul_float_contract
278
273
279
- KernelAbstractions. generate_overdubs (@__MODULE__ , CUDACtx)
280
-
281
- # ##
282
- # CUDA specific method rewrites
283
- # ##
284
-
285
- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float64 , y:: Float64 ) = ^ (x, y)
286
- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float32 , y:: Float32 ) = ^ (x, y)
287
- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float64 , y:: Int32 ) = ^ (x, y)
288
- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float32 , y:: Int32 ) = ^ (x, y)
289
- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Union{Float32, Float64} , y:: Int64 ) = ^ (x, y)
290
-
291
- # libdevice.jl
292
- const cudafuns = (:cos , :cospi , :sin , :sinpi , :tan ,
293
- :acos , :asin , :atan ,
294
- :cosh , :sinh , :tanh ,
295
- :acosh , :asinh , :atanh ,
296
- :log , :log10 , :log1p , :log2 ,
297
- :exp , :exp2 , :exp10 , :expm1 , :ldexp ,
298
- # :isfinite, :isinf, :isnan, :signbit,
299
- :abs ,
300
- :sqrt , :cbrt ,
301
- :ceil , :floor ,)
302
- for f in cudafuns
303
- @eval function Cassette. overdub (ctx:: CUDACtx , :: typeof (Base.$ f), x:: Union{Float32, Float64} )
304
- @Base . _inline_meta
305
- return Base.$ f (x)
306
- end
307
- end
308
-
309
- @inline Cassette. overdub (:: CUDACtx , :: typeof (sincos), x:: Union{Float32, Float64} ) = (Base. sin (x), Base. cos (x))
310
- @inline Cassette. overdub (:: CUDACtx , :: typeof (exp), x:: Union{ComplexF32, ComplexF64} ) = Base. exp (x)
311
-
312
- @inline Cassette. overdub (:: CUDACtx , :: typeof (SpecialFunctions. gamma), x:: Union{Float32, Float64} ) = CUDA. tgamma (x)
313
- @inline Cassette. overdub (:: CUDACtx , :: typeof (SpecialFunctions. erf), x:: Union{Float32, Float64} ) = SpecialFunctions. erf (x)
314
- @inline Cassette. overdub (:: CUDACtx , :: typeof (SpecialFunctions. erfc), x:: Union{Float32, Float64} ) = SpecialFunctions. erfc (x)
315
-
316
274
@static if Base. isbindingresolved (CUDA, :emit_shmem ) && Base. isdefined (CUDA, :emit_shmem )
317
275
const emit_shmem = CUDA. emit_shmem
318
276
else
@@ -325,7 +283,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
325
283
# GPU implementation of shared memory
326
284
# ##
327
285
328
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( SharedMemory), :: Type{T} , :: Val{Dims} , :: Val{Id} ) where {T, Dims, Id}
286
+ @device_override @ inline function SharedMemory ( :: Type{T} , :: Val{Dims} , :: Val{Id} ) where {T, Dims, Id}
329
287
ptr = emit_shmem (T, Val (prod (Dims)))
330
288
CUDA. CuDeviceArray (Dims, ptr)
331
289
end
@@ -335,15 +293,15 @@ end
335
293
# - private memory for each workitem
336
294
# ##
337
295
338
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( Scratchpad), ctx, :: Type{T} , :: Val{Dims} ) where {T, Dims}
296
+ @device_override @ inline function Scratchpad ( ctx, :: Type{T} , :: Val{Dims} ) where {T, Dims}
339
297
MArray {__size(Dims), T} (undef)
340
298
end
341
299
342
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __synchronize) )
300
+ @device_override @ inline function __synchronize ( )
343
301
CUDA. sync_threads ()
344
302
end
345
303
346
- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __print), args... )
304
+ @device_override @ inline function __print ( args... )
347
305
CUDA. _cuprint (args... )
348
306
end
349
307
@@ -356,29 +314,4 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
356
314
# Argument conversion
357
315
KernelAbstractions. argconvert (k:: Kernel{CUDADevice} , arg) = CUDA. cudaconvert (arg)
358
316
359
- # Cassette.jl#195
360
- # Device intrinsics are inferred in a different World (1.6) or using MethodOverlay tables (1.7)
361
- # Cassette sees neither of them and thus overdubbing them fails.
362
- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. arrayref), args... )
363
- CUDA. arrayref (args... )
364
- end
365
- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. arrayset), args... )
366
- CUDA. arrayset (args... )
367
- end
368
- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. const_arrayref), args... )
369
- CUDA. const_arrayref (args... )
370
- end
371
- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. logb), args... )
372
- CUDA. logb (args... )
373
- end
374
- # @inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.tgamma), args...)
375
- # CUDA.tgamma(args...)
376
- # end
377
- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. compute_capability), args... )
378
- CUDA. compute_capability (args... )
379
- end
380
- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. ptx_isa_version), args... )
381
- CUDA. ptx_isa_version (args... )
382
- end
383
-
384
317
end
0 commit comments