Skip to content

Commit 0e6ed99

Browse files
committed
Conform more strictly to the SPIR-V/OpenCL spec.
1 parent b69b2c5 commit 0e6ed99

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

lib/intrinsics/src/work_item.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
# Work-Item Functions
2+
#
3+
# https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
24

35
# NOTE: these functions now unsafely truncate to Int to avoid top bit checks.
46
# we should probably use range metadata instead.
57

68
# 1D values
7-
for (julia_name, (spirv_name, offset)) in [
9+
for (julia_name, (spirv_name, julia_type, offset)) in [
810
# indices
9-
:get_global_linear_id => (:BuiltInGlobalLinearId, 1u32),
10-
:get_local_linear_id => (:BuiltInLocalInvocationIndex, 1u32),
11-
:get_sub_group_id => (:BuiltInSubgroupId, 1u32),
12-
:get_sub_group_local_id => (:BuiltInSubgroupLocalInvocationId, 1u32),
11+
:get_global_linear_id => (:BuiltInGlobalLinearId, Csize_t, 1),
12+
:get_local_linear_id => (:BuiltInLocalInvocationIndex, Csize_t, 1),
13+
:get_sub_group_id => (:BuiltInSubgroupId, UInt32, 1),
14+
:get_sub_group_local_id => (:BuiltInSubgroupLocalInvocationId, UInt32, 1),
1315
# sizes
14-
:get_work_dim => (:BuiltInWorkDim, 0u32),
15-
:get_sub_group_size => (:BuiltInSubgroupSize, 0u32),
16-
:get_max_sub_group_size => (:BuiltInSubgroupMaxSize, 0u32),
17-
:get_num_sub_groups => (:BuiltInNumSubgroups, 0u32),
18-
:get_enqueued_num_sub_groups => (:BuiltInNumEnqueuedSubgroups, 0u32)]
16+
:get_work_dim => (:BuiltInWorkDim, UInt32, 0),
17+
:get_sub_group_size => (:BuiltInSubgroupSize, UInt32, 0),
18+
:get_max_sub_group_size => (:BuiltInSubgroupMaxSize, UInt32, 0),
19+
:get_num_sub_groups => (:BuiltInNumSubgroups, UInt32, 0),
20+
:get_enqueued_num_sub_groups => (:BuiltInNumEnqueuedSubgroups, UInt32, 0)]
1921
gvar_name = Symbol("@__spirv_$(spirv_name)")
20-
width = Int === Int64 ? 64 : 32
22+
width = sizeof(julia_type) * 8
2123
@eval begin
2224
export $julia_name
2325
@device_function $julia_name() =
@@ -28,27 +30,27 @@ for (julia_name, (spirv_name, offset)) in [
2830
ret i$(width) %val
2931
}
3032
attributes #0 = { alwaysinline }
31-
""", "entry"), UInt, Tuple{}) % Int + $offset
33+
""", "entry"), $julia_type, Tuple{}) % Int + $offset
3234
end
3335
end
3436

3537
# 3D values
3638
for (julia_name, (spirv_name, offset)) in [
3739
# indices
38-
:get_global_id => (:BuiltInGlobalInvocationId, 1u32),
39-
:get_global_offset => (:BuiltInGlobalOffset, 1u32),
40-
:get_local_id => (:BuiltInLocalInvocationId, 1u32),
41-
:get_group_id => (:BuiltInWorkgroupId, 1u32),
40+
:get_global_id => (:BuiltInGlobalInvocationId, 1),
41+
:get_global_offset => (:BuiltInGlobalOffset, 1),
42+
:get_local_id => (:BuiltInLocalInvocationId, 1),
43+
:get_group_id => (:BuiltInWorkgroupId, 1),
4244
# sizes
43-
:get_global_size => (:BuiltInGlobalSize, 0u32),
44-
:get_local_size => (:BuiltInWorkgroupSize, 0u32),
45-
:get_enqueued_local_size => (:BuiltInEnqueuedWorkgroupSize, 0u32),
46-
:get_num_groups => (:BuiltInNumWorkgroups, 0u32)]
45+
:get_global_size => (:BuiltInGlobalSize, 0),
46+
:get_local_size => (:BuiltInWorkgroupSize, 0),
47+
:get_enqueued_local_size => (:BuiltInEnqueuedWorkgroupSize, 0),
48+
:get_num_groups => (:BuiltInNumWorkgroups, 0)]
4749
gvar_name = Symbol("@__spirv_$(spirv_name)")
4850
width = Int === Int64 ? 64 : 32
4951
@eval begin
5052
export $julia_name
51-
@device_function $julia_name(dimindx::Integer=1) =
53+
@device_function $julia_name(dimindx::Integer=1u32) =
5254
Base.llvmcall(
5355
$("""$gvar_name = external addrspace($(AS.Input)) global <3 x i$(width)>
5456
define i$(width) @entry(i$(width) %idx) #0 {

src/compiler/compilation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747
supports_fp64 = "cl_khr_fp64" in dev.extensions
4848

4949
# create GPUCompiler objects
50-
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, kwargs...)
50+
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, kwargs...)
5151
params = OpenCLCompilerParams()
5252
CompilerConfig(target, params; kernel, name, always_inline)
5353
end

0 commit comments

Comments
 (0)