Skip to content

Commit 8da4ccb

Browse files
committed
adding framework to build off of
1 parent bf138c4 commit 8da4ccb

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ else
319319
const emit_shmem = CUDA._shmem
320320
end
321321

322-
import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize, __size
322+
import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize, __size, __atomic
323323

324324
###
325325
# GPU implementation of shared memory
@@ -381,4 +381,12 @@ end
381381
CUDA.ptx_isa_version(args...)
382382
end
383383

384+
###
385+
# GPU implementation of atomics
386+
###
387+
388+
# Cassette overdub for calls to `__atomic` made under a `CUDACtx`; its body
# invokes `CUDA.@atomic(ex)`.
# NOTE(review): `CUDA.@atomic` is a macro, so it is handed the symbol `ex` when this
# method definition is expanded, not the runtime value passed to the overdub —
# confirm this expands and behaves as intended.
@inline function Cassette.overdub(::CUDACtx, ::typeof(__atomic), ex)
389+
CUDA.@atomic(ex)
390+
end
391+
384392
end

src/KernelAbstractions.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelAbstractions
22

33
export @kernel
4-
export @Const, @localmem, @private, @uniform, @synchronize, @index, @groupsize, @print
4+
export @Const, @localmem, @private, @uniform, @synchronize, @index, @groupsize, @print, @atomic
55
export Device, GPU, CPU, Event, MultiEvent, NoneEvent
66
export async_copy!
77

@@ -28,6 +28,7 @@ and then invoked on the arguments.
2828
- [`@uniform`](@ref)
2929
- [`@synchronize`](@ref)
3030
- [`@print`](@ref)
31+
- [`@atomic`](@ref)
3132
3233
# Example:
3334
@@ -306,6 +307,22 @@ macro index(locale, args...)
306307
Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
307308
end
308309

310+
"""
311+
@atomic command
312+
313+
This is a unified atomic interface
314+
315+
# Platform differences
316+
- `GPU`: This uses standard `@atomic` calls from CUDA.jl
317+
- `CPU`: This reorganizes the command to use atomic pointer logic
318+
"""
319+
320+
# `@atomic ex` expands to a call of the internal `__atomic` function, handing it the
# user's expression `ex` as a quoted value so that a backend can intercept the call.
#
# The call is built explicitly, in the same way as the `@index` macro:
# - `GlobalRef` pins `__atomic` to this module regardless of the caller's scope.
# - `QuoteNode(ex)` splices the macro argument itself into the call. Without it the
#   expansion refers to a global named `ex` in this module, not the user's expression.
# - No `return` is emitted into the caller, since that would return from whatever
#   function encloses the `@atomic` use.
# NOTE(review): passing the expression quoted (rather than escaped and evaluated) is
# assumed to be the intent, given that `__atomic` takes an `ex` argument — confirm.
macro atomic(ex)
    return Expr(:call, GlobalRef(KernelAbstractions, :__atomic), QuoteNode(ex))
end
325+
309326
###
310327
# Internal kernel functions
311328
###
@@ -451,6 +468,10 @@ function __synchronize()
451468
error("@synchronize used outside kernel or not captured")
452469
end
453470

471+
# Default method of `__atomic`: always raises an error reporting that `@atomic` was
# used outside a kernel or was not captured.
__atomic(ex) = error("@atomic used outside kernel or not captured")
474+
454475
@generated function __print(items...)
455476
str = ""
456477
args = []

0 commit comments

Comments
 (0)