Skip to content

Commit ad00c10

Browse files
authored
Switch to the LLVM SPIR-V back-end. (#285)
Also makes SPIRVIntrinsics contain actual SPIR-V intrinsics. Previously, the package mostly provided OpenCL intrinsics and relied on the Khronos translator to map them to SPIR-V ones.
1 parent 9247b80 commit ad00c10

File tree

16 files changed

+99
-57
lines changed

16 files changed

+99
-57
lines changed

Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,23 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1616
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
17-
SPIRV_LLVM_Translator_jll = "4a5d46fc-d8cf-5151-a261-86b458210efb"
17+
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
18+
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
1819
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1920

2021
[compat]
2122
Adapt = "4"
2223
GPUArrays = "11.2.1"
23-
GPUCompiler = "0.27, 1"
24+
GPUCompiler = "1.2"
2425
KernelAbstractions = "0.9.2"
2526
LLVM = "9.1"
2627
LinearAlgebra = "1"
2728
OpenCL_jll = "=2024.10.24"
2829
Printf = "1"
2930
Random = "1"
3031
Reexport = "1"
31-
SPIRVIntrinsics = "0.2"
32-
SPIRV_LLVM_Translator_jll = "20"
32+
SPIRVIntrinsics = "0.3"
33+
SPIRV_LLVM_Backend_jll = "20"
34+
SPIRV_Tools_jll = "2024.4"
3335
StaticArrays = "1"
3436
julia = "1.10"

lib/intrinsics/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SPIRVIntrinsics"
22
uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
33
authors = ["Tim Besard <tim.besard@gmail.com>"]
4-
version = "0.2.1"
4+
version = "0.3.0"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"

lib/intrinsics/src/atomic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
const atomic_integer_types = [UInt32, Int32]
99
# TODO: 64-bit atomics with ZE_DEVICE_MODULE_FLAG_INT64_ATOMICS
1010
# TODO: additional floating-point atomics with ZE_extension_float_atomics
11-
const atomic_memory_types = [AS.Local, AS.Global]
11+
const atomic_memory_types = [AS.Workgroup, AS.CrossWorkgroup]
1212

1313

1414
# generically typed

lib/intrinsics/src/math.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ for gentype in generic_types
9494
cosval = Ref{$gentype}()
9595
sinval = GC.@preserve cosval begin
9696
ptr = Base.unsafe_convert(Ptr{$gentype}, cosval)
97-
llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Private}, ptr)
98-
@builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Private}), x, llvm_ptr)
97+
llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Function}, ptr)
98+
@builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Function}), x, llvm_ptr)
9999
end
100100
return sinval, cosval[]
101101
end

lib/intrinsics/src/memory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
Context() do ctx
66
# XXX: as long as LLVMPtr is emitted as i8*, it doesn't make sense to type the GV
77
eltyp = convert(LLVMType, LLVM.Int8Type())
8-
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Local})
8+
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Workgroup})
99

1010
# create a function
1111
llvm_f, _ = create_function(T_ptr)
1212

1313
# create the global variable
1414
mod = LLVM.parent(llvm_f)
1515
gv_typ = LLVM.ArrayType(eltyp, len * sizeof(T))
16-
gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Local)
16+
gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Workgroup)
1717
if len > 0
1818
linkage!(gv, LLVM.API.LLVMInternalLinkage)
1919
initializer!(gv, null(gv_typ))
@@ -33,6 +33,6 @@
3333
ret!(builder, untyped_ptr)
3434
end
3535

36-
call_function(llvm_f, LLVMPtr{T,AS.Local})
36+
call_function(llvm_f, LLVMPtr{T,AS.Workgroup})
3737
end
3838
end

lib/intrinsics/src/pointer.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@ export AS
44

55
module AS
66

7-
const Private = 0
8-
const Global = 1
9-
const Constant = 2
10-
const Local = 3
11-
const Generic = 4
12-
const Input = 5
13-
const Output = 6
14-
const Count = 7
7+
const Function = 0
8+
const CrossWorkgroup = 1
9+
const UniformConstant = 2
10+
const Workgroup = 3
11+
const Generic = 4
12+
const DeviceOnlyINTEL = 5 # XXX: should be CrossWorkgroup
13+
const HostOnlyINTEL = 6 # when USM is not supported
14+
const Input = 7
15+
const Output = 8
16+
const CodeSectionINTEL = 9
17+
const Private = 10
1518

1619
end

lib/intrinsics/src/printf.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
Context() do ctx
3232
T_void = LLVM.VoidType()
3333
T_int32 = LLVM.Int32Type()
34-
T_pint8 = LLVM.PointerType(LLVM.Int8Type(), AS.Constant)
34+
T_pint8 = LLVM.PointerType(LLVM.Int8Type(), AS.UniformConstant)
3535

3636
# create functions
3737
param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
@@ -80,7 +80,7 @@ end
8080
push!(actual_args, actual_arg)
8181
end
8282

83-
str = globalstring_ptr!(builder, String(fmt); addrspace=AS.Constant)
83+
str = globalstring_ptr!(builder, String(fmt); addrspace=AS.UniformConstant)
8484

8585
# invoke printf and return
8686
printf_typ = LLVM.FunctionType(T_int32, [T_pint8]; vararg=true)

lib/intrinsics/src/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,22 @@ macro builtin_ccall(name, ret, argtypes, args...)
3434
elt, as = T.parameters
3535

3636
# mangle address space
37-
ASstr = if as == AS.Global
37+
ASstr = if as == AS.CrossWorkgroup
3838
"CLglobal"
3939
#elseif as == AS.Global_device
4040
# "CLdevice"
4141
#elseif as == AS.Global_host
4242
# "CLhost"
43-
elseif as == AS.Local
43+
elseif as == AS.Workgroup
4444
"CLlocal"
45-
elseif as == AS.Constant
45+
elseif as == AS.UniformConstant
4646
"CLconstant"
47-
elseif as == AS.Private
47+
elseif as == AS.Function
4848
"CLprivate"
4949
elseif as == AS.Generic
5050
"CLgeneric"
5151
else
52-
error("Unknown address space $AS")
52+
error("Unknown address space $as")
5353
end
5454

5555
# encode as vendor qualifier

lib/intrinsics/src/work_item.jl

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,64 @@
11
# Work-Item Functions
2-
3-
export get_work_dim,
4-
get_global_size, get_global_id,
5-
get_local_size, get_enqueued_local_size, get_local_id,
6-
get_num_groups, get_group_id,
7-
get_global_offset,
8-
get_global_linear_id, get_local_linear_id
2+
#
3+
# https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
94

105
# NOTE: these functions now unsafely truncate to Int to avoid top bit checks.
116
# we should probably use range metadata instead.
127

13-
@device_function get_work_dim() = @builtin_ccall("get_work_dim", UInt32, ()) % Int
14-
15-
@device_function get_global_size(dimindx::Integer = 1u32) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx - 1u32) % Int
16-
@device_function get_global_id(dimindx::Integer = 1u32) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx - 1u32) % Int + 1
17-
18-
@device_function get_local_size(dimindx::Integer = 1u32) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx - 1u32) % Int
19-
@device_function get_enqueued_local_size(dimindx::Integer = 1u32) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx - 1u32) % Int
20-
@device_function get_local_id(dimindx::Integer = 1u32) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx - 1u32) % Int + 1
21-
22-
@device_function get_num_groups(dimindx::Integer = 1u32) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx - 1u32) % Int
23-
@device_function get_group_id(dimindx::Integer = 1u32) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx - 1u32) % Int + 1
24-
25-
@device_function get_global_offset(dimindx::Integer = 1u32) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx - 1u32) % Int + 1
8+
# 1D values
9+
for (julia_name, (spirv_name, julia_type, offset)) in [
10+
# indices
11+
:get_global_linear_id => (:BuiltInGlobalLinearId, Csize_t, 1),
12+
:get_local_linear_id => (:BuiltInLocalInvocationIndex, Csize_t, 1),
13+
:get_sub_group_id => (:BuiltInSubgroupId, UInt32, 1),
14+
:get_sub_group_local_id => (:BuiltInSubgroupLocalInvocationId, UInt32, 1),
15+
# sizes
16+
:get_work_dim => (:BuiltInWorkDim, UInt32, 0),
17+
:get_sub_group_size => (:BuiltInSubgroupSize, UInt32, 0),
18+
:get_max_sub_group_size => (:BuiltInSubgroupMaxSize, UInt32, 0),
19+
:get_num_sub_groups => (:BuiltInNumSubgroups, UInt32, 0),
20+
:get_enqueued_num_sub_groups => (:BuiltInNumEnqueuedSubgroups, UInt32, 0)]
21+
gvar_name = Symbol("@__spirv_$(spirv_name)")
22+
width = sizeof(julia_type) * 8
23+
@eval begin
24+
export $julia_name
25+
@device_function $julia_name() =
26+
Base.llvmcall(
27+
$("""$gvar_name = external addrspace($(AS.Input)) global i$(width)
28+
define i$(width) @entry() #0 {
29+
%val = load i$(width), i$(width) addrspace($(AS.Input))* $gvar_name
30+
ret i$(width) %val
31+
}
32+
attributes #0 = { alwaysinline }
33+
""", "entry"), $julia_type, Tuple{}) % Int + $offset
34+
end
35+
end
2636

27-
@device_function get_global_linear_id() = @builtin_ccall("get_global_linear_id", UInt, ()) % Int + 1
28-
@device_function get_local_linear_id() = @builtin_ccall("get_local_linear_id", UInt, ()) % Int + 1
37+
# 3D values
38+
for (julia_name, (spirv_name, offset)) in [
39+
# indices
40+
:get_global_id => (:BuiltInGlobalInvocationId, 1),
41+
:get_global_offset => (:BuiltInGlobalOffset, 1),
42+
:get_local_id => (:BuiltInLocalInvocationId, 1),
43+
:get_group_id => (:BuiltInWorkgroupId, 1),
44+
# sizes
45+
:get_global_size => (:BuiltInGlobalSize, 0),
46+
:get_local_size => (:BuiltInWorkgroupSize, 0),
47+
:get_enqueued_local_size => (:BuiltInEnqueuedWorkgroupSize, 0),
48+
:get_num_groups => (:BuiltInNumWorkgroups, 0)]
49+
gvar_name = Symbol("@__spirv_$(spirv_name)")
50+
width = Int === Int64 ? 64 : 32
51+
@eval begin
52+
export $julia_name
53+
@device_function $julia_name(dimindx::Integer=1u32) =
54+
Base.llvmcall(
55+
$("""$gvar_name = external addrspace($(AS.Input)) global <3 x i$(width)>
56+
define i$(width) @entry(i$(width) %idx) #0 {
57+
%val = load <3 x i$(width)>, <3 x i$(width)> addrspace($(AS.Input))* $gvar_name
58+
%element = extractelement <3 x i$(width)> %val, i$(width) %idx
59+
ret i$(width) %element
60+
}
61+
attributes #0 = { alwaysinline }
62+
""", "entry"), UInt, Tuple{UInt}, UInt(dimindx - 1u32)) % Int + $offset
63+
end
64+
end

src/OpenCL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OpenCL
22

33
using GPUCompiler
44
using LLVM, LLVM.Interop
5-
using SPIRV_LLVM_Translator_jll
5+
using SPIRV_LLVM_Backend_jll, SPIRV_Tools_jll
66
using Adapt
77
using Reexport
88
using GPUArrays

0 commit comments

Comments
 (0)