Skip to content

Commit ea19986

Browse files
committed
Port work item intrinsics.
1 parent 98997cf commit ea19986

File tree

1 file changed

+55
-23
lines changed

1 file changed

+55
-23
lines changed

lib/intrinsics/src/work_item.jl

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,60 @@
11
# Work-Item Functions
22

3-
export get_work_dim,
4-
get_global_size, get_global_id,
5-
get_local_size, get_enqueued_local_size, get_local_id,
6-
get_num_groups, get_group_id,
7-
get_global_offset,
8-
get_global_linear_id, get_local_linear_id
9-
103
# NOTE: these functions now unsafely truncate to Int to avoid top bit checks.
114
# we should probably use range metadata instead.
125

13-
@device_function get_work_dim() = @builtin_ccall("get_work_dim", UInt32, ()) % Int
14-
15-
@device_function get_global_size(dimindx::Integer=1) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx-1) % Int
16-
@device_function get_global_id(dimindx::Integer=1) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx-1) % Int + 1
17-
18-
@device_function get_local_size(dimindx::Integer=1) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx-1) % Int
19-
@device_function get_enqueued_local_size(dimindx::Integer=1) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx-1) % Int
20-
@device_function get_local_id(dimindx::Integer=1) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx-1) % Int + 1
21-
22-
@device_function get_num_groups(dimindx::Integer=1) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx-1) % Int
23-
@device_function get_group_id(dimindx::Integer=1) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx-1) % Int + 1
24-
25-
@device_function get_global_offset(dimindx::Integer=1) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx-1) % Int + 1
26-
27-
@device_function get_global_linear_id() = @builtin_ccall("get_global_linear_id", UInt, ()) % Int + 1
28-
@device_function get_local_linear_id() = @builtin_ccall("get_local_linear_id", UInt, ()) % Int + 1
6+
# 1D values
7+
for (julia_name, (spirv_name, offset)) in [
8+
# indices
9+
:get_global_linear_id => (:BuiltInGlobalLinearId, 1),
10+
:get_local_linear_id => (:BuiltInLocalInvocationIndex, 1),
11+
:get_sub_group_id => (:BuiltInSubgroupId, 1),
12+
:get_sub_group_local_id => (:BuiltInSubgroupLocalInvocationId, 1),
13+
# sizes
14+
:get_work_dim => (:BuiltInWorkDim, 0),
15+
:get_sub_group_size => (:BuiltInSubgroupSize, 0),
16+
:get_max_sub_group_size => (:BuiltInSubgroupMaxSize, 0),
17+
:get_num_sub_groups => (:BuiltInNumSubgroups, 0),
18+
:get_enqueued_num_sub_groups => (:BuiltInNumEnqueuedSubgroups, 0)]
19+
gvar_name = Symbol("@__spirv_$(spirv_name)")
20+
@eval begin
21+
export $julia_name
22+
@device_function $julia_name() =
23+
Base.llvmcall(
24+
$("""$gvar_name = external addrspace(1) global i32
25+
define i32 @entry() #0 {
26+
%val = load i32, i32 addrspace(1)* $gvar_name
27+
ret i32 %val
28+
}
29+
attributes #0 = { alwaysinline }
30+
""", "entry"), UInt32, Tuple{}) % Int + $offset
31+
end
32+
end
33+
34+
# 3D values
35+
for (julia_name, (spirv_name, offset)) in [
36+
# indices
37+
:get_global_id => (:BuiltInGlobalInvocationId, 1),
38+
:get_global_offset => (:BuiltInGlobalOffset, 1),
39+
:get_local_id => (:BuiltInLocalInvocationId, 1),
40+
:get_group_id => (:BuiltInWorkgroupId, 1),
41+
# sizes
42+
:get_global_size => (:BuiltInGlobalSize, 0),
43+
:get_local_size => (:BuiltInWorkgroupSize, 0),
44+
:get_enqueued_local_size => (:BuiltInEnqueuedWorkgroupSize, 0),
45+
:get_num_groups => (:BuiltInNumWorkgroups, 0)]
46+
gvar_name = Symbol("@__spirv_$(spirv_name)")
47+
@eval begin
48+
export $julia_name
49+
@device_function $julia_name(dimindx::Integer=1) =
50+
Base.llvmcall(
51+
$("""$gvar_name = external addrspace(1) global <3 x i32>
52+
define i32 @entry(i32 %idx) #0 {
53+
%val = load <3 x i32>, <3 x i32> addrspace(1)* $gvar_name
54+
%element = extractelement <3 x i32> %val, i32 %idx
55+
ret i32 %element
56+
}
57+
attributes #0 = { alwaysinline }
58+
""", "entry"), UInt32, Tuple{UInt32}, UInt32(dimindx - 1)) % Int + $offset
59+
end
60+
end

0 commit comments

Comments
 (0)