Skip to content

Commit 25efed2

Browse files
committed
Switch to SPIR-V address spaces.
1 parent ea19986 commit 25efed2

File tree

10 files changed

+34
-30
lines changed

10 files changed

+34
-30
lines changed

lib/intrinsics/src/atomic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
const atomic_integer_types = [UInt32, Int32]
99
# TODO: 64-bit atomics with ZE_DEVICE_MODULE_FLAG_INT64_ATOMICS
1010
# TODO: additional floating-point atomics with ZE_extension_float_atomics
11-
const atomic_memory_types = [AS.Local, AS.Global]
11+
const atomic_memory_types = [AS.Workgroup, AS.CrossWorkgroup]
1212

1313

1414
# generically typed

lib/intrinsics/src/math.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ for gentype in generic_types
9494
cosval = Ref{$gentype}()
9595
sinval = GC.@preserve cosval begin
9696
ptr = Base.unsafe_convert(Ptr{$gentype}, cosval)
97-
llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Private}, ptr)
98-
@builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Private}), x, llvm_ptr)
97+
llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Function}, ptr)
98+
@builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Function}), x, llvm_ptr)
9999
end
100100
return sinval, cosval[]
101101
end

lib/intrinsics/src/memory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
Context() do ctx
66
# XXX: as long as LLVMPtr is emitted as i8*, it doesn't make sense to type the GV
77
eltyp = convert(LLVMType, LLVM.Int8Type())
8-
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Local})
8+
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Workgroup})
99

1010
# create a function
1111
llvm_f, _ = create_function(T_ptr)
1212

1313
# create the global variable
1414
mod = LLVM.parent(llvm_f)
1515
gv_typ = LLVM.ArrayType(eltyp, len * sizeof(T))
16-
gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Local)
16+
gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Workgroup)
1717
if len > 0
1818
linkage!(gv, LLVM.API.LLVMInternalLinkage)
1919
initializer!(gv, null(gv_typ))
@@ -33,6 +33,6 @@
3333
ret!(builder, untyped_ptr)
3434
end
3535

36-
call_function(llvm_f, LLVMPtr{T,AS.Local})
36+
call_function(llvm_f, LLVMPtr{T,AS.Workgroup})
3737
end
3838
end

lib/intrinsics/src/pointer.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@ export AS
44

55
module AS
66

7-
const Private = 0
8-
const Global = 1
9-
const Constant = 2
10-
const Local = 3
11-
const Generic = 4
12-
const Input = 5
13-
const Output = 6
14-
const Count = 7
7+
const Function = 0
8+
const CrossWorkgroup = 1
9+
const UniformConstant = 2
10+
const Workgroup = 3
11+
const Generic = 4
12+
const DeviceOnlyINTEL = 5 # XXX: should be CrossWorkgroup
13+
const HostOnlyINTEL = 6 # when USM is not supported
14+
const Input = 7
15+
const Output = 8
16+
const CodeSectionINTEL = 9
17+
const Private = 10
1518

1619
end

lib/intrinsics/src/printf.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
Context() do ctx
3232
T_void = LLVM.VoidType()
3333
T_int32 = LLVM.Int32Type()
34-
T_pint8 = LLVM.PointerType(LLVM.Int8Type(), AS.Constant)
34+
T_pint8 = LLVM.PointerType(LLVM.Int8Type(), AS.UniformConstant)
3535

3636
# create functions
3737
param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
@@ -43,7 +43,7 @@ end
4343
entry = BasicBlock(llvm_f, "entry")
4444
position!(builder, entry)
4545

46-
str = globalstring_ptr!(builder, String(fmt); addrspace=AS.Constant)
46+
str = globalstring_ptr!(builder, String(fmt); addrspace=AS.UniformConstant)
4747

4848
# invoke printf and return
4949
printf_typ = LLVM.FunctionType(T_int32, [T_pint8]; vararg=true)

lib/intrinsics/src/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,22 @@ macro builtin_ccall(name, ret, argtypes, args...)
3434
elt, as = T.parameters
3535

3636
# mangle address space
37-
ASstr = if as == AS.Global
37+
ASstr = if as == AS.CrossWorkgroup
3838
"CLglobal"
3939
#elseif as == AS.Global_device
4040
# "CLdevice"
4141
#elseif as == AS.Global_host
4242
# "CLhost"
43-
elseif as == AS.Local
43+
elseif as == AS.Workgroup
4444
"CLlocal"
45-
elseif as == AS.Constant
45+
elseif as == AS.UniformConstant
4646
"CLconstant"
47-
elseif as == AS.Private
47+
elseif as == AS.Function
4848
"CLprivate"
4949
elseif as == AS.Generic
5050
"CLgeneric"
5151
else
52-
error("Unknown address space $AS")
52+
error("Unknown address space $as")
5353
end
5454

5555
# encode as vendor qualifier

lib/intrinsics/src/work_item.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ for (julia_name, (spirv_name, offset)) in [
2121
export $julia_name
2222
@device_function $julia_name() =
2323
Base.llvmcall(
24-
$("""$gvar_name = external addrspace(1) global i32
24+
$("""$gvar_name = external addrspace($(AS.Input)) global i32
2525
define i32 @entry() #0 {
26-
%val = load i32, i32 addrspace(1)* $gvar_name
26+
%val = load i32, i32 addrspace($(AS.Input))* $gvar_name
2727
ret i32 %val
2828
}
2929
attributes #0 = { alwaysinline }
@@ -48,9 +48,9 @@ for (julia_name, (spirv_name, offset)) in [
4848
export $julia_name
4949
@device_function $julia_name(dimindx::Integer=1) =
5050
Base.llvmcall(
51-
$("""$gvar_name = external addrspace(1) global <3 x i32>
51+
$("""$gvar_name = external addrspace($(AS.Input)) global <3 x i32>
5252
define i32 @entry(i32 %idx) #0 {
53-
%val = load <3 x i32>, <3 x i32> addrspace(1)* $gvar_name
53+
%val = load <3 x i32>, <3 x i32> addrspace($(AS.Input))* $gvar_name
5454
%element = extractelement <3 x i32> %val, i32 %idx
5555
ret i32 %element
5656
}

src/array.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,10 @@ end
299299

300300
## interop with GPU arrays
301301

302-
function Base.unsafe_convert(::Type{CLDeviceArray{T, N, AS.Global}}, a::CLArray{T, N}) where {T, N}
303-
return CLDeviceArray{T, N, AS.Global}(
304-
size(a), reinterpret(LLVMPtr{T, AS.Global}, pointer(a)),
302+
function Base.unsafe_convert(::Type{CLDeviceArray{T, N, AS.CrossWorkgroup}},
303+
a::CLArray{T, N}) where {T, N}
304+
return CLDeviceArray{T, N, AS.CrossWorkgroup}(
305+
size(a), reinterpret(LLVMPtr{T, AS.CrossWorkgroup}, pointer(a)),
305306
a.maxsize - a.offset * Base.elsize(a)
306307
)
307308
end

src/compiler/execution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function Adapt.adapt_storage(to::KernelAdaptor, buf::cl.AbstractMemory)
9292
end
9393
function Adapt.adapt_storage(to::KernelAdaptor, arr::CLArray{T, N}) where {T, N}
9494
push!(to.indirect_memory, arr.data[].mem)
95-
return Base.unsafe_convert(CLDeviceArray{T, N, AS.Global}, arr)
95+
return Base.unsafe_convert(CLDeviceArray{T, N, AS.CrossWorkgroup}, arr)
9696
end
9797

9898
# Base.RefValue isn't GPU compatible, so provide a compatible alternative

test/execution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ end
9393

9494
@test OpenCL.return_type(identity, Tuple{Int}) === Int
9595
@test OpenCL.return_type(sin, Tuple{Float32}) === Float32
96-
@test OpenCL.return_type(getindex, Tuple{CLDeviceArray{Float32,1,AS.Global},Int32}) === Float32
96+
@test OpenCL.return_type(getindex, Tuple{CLDeviceArray{Float32,1,AS.CrossWorkgroup},Int32}) === Float32
9797
@test OpenCL.return_type(getindex, Tuple{Base.RefValue{Integer}}) === Integer
9898
end
9999

0 commit comments

Comments
 (0)