2
2
3
3
export @cuStaticSharedMem , @cuDynamicSharedMem
4
4
5
- shmem_id = 0
6
-
7
5
"""
8
6
@cuStaticSharedMem(T::Type, dims) -> CuDeviceArray{T,AS.Shared}
9
7
@@ -13,15 +11,9 @@ inferable and the dimensions should be constant, or an error will be thrown and
13
11
generator function will be called dynamically.
14
12
"""
15
13
macro cuStaticSharedMem (T, dims)
16
- # FIXME : generating a unique id in the macro is incorrect, as multiple parametrically typed
17
- # functions will alias the id (and the size might be a parameter). but incrementing in
18
- # the @generated function doesn't work, as it is supposed to be pure and identical
19
- # invocations will erroneously share (and even cause multiple shmem globals).
20
- id = gensym (" static_shmem" )
21
-
22
14
quote
23
15
len = prod ($ (esc (dims)))
24
- ptr = emit_shmem (Val ( $ ( QuoteNode (id))), $ (esc (T)), Val (len))
16
+ ptr = emit_shmem ($ (esc (T)), Val (len))
25
17
CuDeviceArray ($ (esc (dims)), ptr)
26
18
end
27
19
end
@@ -40,19 +32,16 @@ pointer can be specified. This is useful when dealing with a heterogeneous buffe
40
32
shared memory; in the case of a homogeneous multi-part buffer it is preferred to use `view`.
41
33
"""
42
34
macro cuDynamicSharedMem (T, dims, offset= 0 )
43
- id = gensym (" dynamic_shmem" )
44
-
45
35
# TODO : boundscheck against %dynamic_smem_size (currently unsupported by LLVM)
46
-
47
36
quote
48
37
len = prod ($ (esc (dims)))
49
- ptr = emit_shmem (Val ( $ ( QuoteNode (id))), $ (esc (T))) + $ (esc (offset))
38
+ ptr = emit_shmem ($ (esc (T))) + $ (esc (offset))
50
39
CuDeviceArray ($ (esc (dims)), ptr)
51
40
end
52
41
end
53
42
54
43
# get a pointer to shared memory, with known (static) or zero length (dynamic shared memory)
55
- @generated function emit_shmem (:: Val{id} , :: Type{T} , :: Val{len} = Val (0 )) where {id, T,len}
44
+ @generated function emit_shmem (:: Type{T} , :: Val{len} = Val (0 )) where {T,len}
56
45
Context () do ctx
57
46
eltyp = convert (LLVMType, T; ctx)
58
47
T_ptr = convert (LLVMType, LLVMPtr{T,AS. Shared}; ctx)
63
52
# create the global variable
64
53
mod = LLVM. parent (llvm_f)
65
54
gv_typ = LLVM. ArrayType (eltyp, len)
66
- gv = GlobalVariable (mod, gv_typ, GPUCompiler . safe_name ( string (id)) , AS. Shared)
55
+ gv = GlobalVariable (mod, gv_typ, " shmem " , AS. Shared)
67
56
if len > 0
68
57
# static shared memory should be demoted to local variables, whenever possible.
69
58
# this is done by the NVPTX ASM printer:
0 commit comments