6
6
7
7
struct PtrAccessor
8
8
{
9
// Factory: returns a PtrAccessor whose `addr` member is set to the given 64-bit address.
// The removed/added signature pair below differs only by the `const` qualifier on the parameter.
- static PtrAccessor create (uint64_t addr)
9
+ static PtrAccessor create (const uint64_t addr)
10
10
{
11
11
PtrAccessor ptr;
12
12
ptr.addr = addr;
13
13
return ptr;
14
14
}
15
15
16
// Reads one uint32_t through a nbl::hlsl::bda::__ptr built from `addr`.
// `index` is an element index: it is scaled by sizeof(uint32_t) to form the byte offset.
// No range check on `index` is performed here.
- uint32_t get (uint64_t index)
16
+ uint32_t get (const uint64_t index)
17
17
{
18
18
// Address of element `index` = addr + sizeof(uint32_t) * index; deref().load() fetches it.
return nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
19
+
19
20
deref ().load ();
20
21
}
21
22
22
// Writes `value` as a uint32_t through a nbl::hlsl::bda::__ptr built from `addr`.
// `index` is an element index: it is scaled by sizeof(uint32_t) to form the byte offset.
// No range check on `index` is performed here.
- void set (uint64_t index, uint32_t value)
23
+ void set (const uint64_t index, const uint32_t value)
23
24
{
24
25
// Address of element `index` = addr + sizeof(uint32_t) * index; deref().store() writes it.
nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
26
+
25
27
deref ().store (value);
26
28
}
27
29
28
- uint32_t atomicAdd (uint64_t index, uint32_t value)
30
+ uint32_t atomicAdd (const uint64_t index, const uint32_t value)
29
31
{
30
32
nbl::hlsl::bda::__spv_ptr_t < uint32_t > ptr = nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
33
+
31
34
deref ().get_ptr ();
32
35
33
36
return nbl::hlsl::glsl::atomicAdd (ptr, value);
@@ -36,23 +39,48 @@ struct PtrAccessor
36
39
uint64_t addr;
37
40
};
38
41
42
// Group-shared array of BucketCount uint32_t entries; SharedAccessor reads and writes it.
// BucketCount is defined outside this excerpt.
+ groupshared uint32_t sdata[BucketCount];
43
+
44
// Accessor over the groupshared `sdata` array. It is stateless: every method
// indexes `sdata` directly, with no bounds check on `index`.
// It is passed to nbl::hlsl::sort::counting as its last template argument in main().
+ struct SharedAccessor
45
+ {
46
// Returns sdata[index].
+ uint32_t get (const uint32_t index)
47
+ {
48
+ return sdata[index];
49
+ }
50
+
51
// Overwrites sdata[index] with `value` (plain, non-atomic store).
+ void set (const uint32_t index, const uint32_t value)
52
+ {
53
+ sdata[index] = value;
54
+ }
55
+
56
// Forwards to nbl::hlsl::glsl::atomicAdd on sdata[index] and returns its result.
// NOTE(review): presumably the value held before the add, as with GLSL atomicAdd — confirm.
+ uint32_t atomicAdd (const uint32_t index, const uint32_t value)
57
+ {
58
+ return nbl::hlsl::glsl::atomicAdd (sdata[index], value);
59
+ }
60
+
61
// Forwards to nbl::hlsl::glsl::barrier().
// NOTE(review): the method name promises both an execution and a memory barrier;
// confirm that glsl::barrier() provides both for groupshared memory.
+ void workgroupExecutionAndMemoryBarrier ()
62
+ {
63
+ nbl::hlsl::glsl::barrier ();
64
+ }
65
+ };
66
+
39
67
struct DoublePtrAccessor
40
68
{
41
// Factory: returns a DoublePtrAccessor with `in_addr` and `out_addr` set from the arguments.
// The removed/added signature pair below differs only by the `const` qualifiers on the parameters.
- static DoublePtrAccessor create (uint64_t in_addr, uint64_t out_addr)
69
+ static DoublePtrAccessor create (const uint64_t in_addr, const uint64_t out_addr)
42
70
{
43
71
DoublePtrAccessor ptr;
44
72
ptr.in_addr = in_addr;
45
73
ptr.out_addr = out_addr;
46
74
return ptr;
47
75
}
48
76
49
// Reads one uint32_t through a nbl::hlsl::bda::__ptr built from `in_addr` (the input side).
// `index` is an element index: it is scaled by sizeof(uint32_t) to form the byte offset.
// No range check on `index` is performed here.
- uint32_t get (uint64_t index)
77
+ uint32_t get (const uint64_t index)
50
78
{
51
79
// Address of element `index` = in_addr + sizeof(uint32_t) * index.
return nbl::hlsl::bda::__ptr < uint32_t > (in_addr + sizeof (uint32_t) * index).template
52
80
deref ().load ();
53
81
}
54
82
55
- void set (uint64_t index, uint32_t value)
83
+ void set (const uint64_t index, const uint32_t value)
56
84
{
57
85
nbl::hlsl::bda::__ptr < uint32_t > (out_addr + sizeof (uint32_t) * index).template
58
86
deref ().store (value);
@@ -70,19 +98,23 @@ uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
70
98
// Shader entry point for the scatter step.
// Fills a CountingParameters<uint32_t> from WorkgroupSize/BucketCount and from pushData,
// builds the key, value, scratch and shared accessors, then calls counter.scatter().
// `ID` and `GroupID` are not referenced anywhere in this body.
void main (uint32_t3 ID : SV_GroupThreadID , uint32_t3 GroupID : SV_GroupID )
71
99
{
72
100
nbl::hlsl::sort::CountingParameters < uint32_t > params;
101
// Added in this change: WorkgroupSize and BucketCount are also copied into the runtime params.
+ params.workgroupSize = WorkgroupSize;
102
+ params.bucketCount = BucketCount;
73
103
params.dataElementCount = pushData.dataElementCount;
74
104
params.elementsPerWT = pushData.elementsPerWT;
75
105
params.minimum = pushData.minimum;
76
106
params.maximum = pushData.maximum;
77
107
78
// The counting template now additionally takes WorkgroupSize, BucketCount and a SharedAccessor type.
- nbl::hlsl::sort::counting <uint32_t, DoublePtrAccessor, DoublePtrAccessor, PtrAccessor> counter;
108
+ nbl::hlsl::sort::counting <WorkgroupSize, BucketCount, uint32_t, DoublePtrAccessor, DoublePtrAccessor, PtrAccessor, SharedAccessor > counter;
79
109
// Keys and values each get an input/output address pair; scratch gets a single base address.
DoublePtrAccessor key_accessor = DoublePtrAccessor::create (pushData.inputKeyAddress, pushData.outputKeyAddress);
80
110
DoublePtrAccessor value_accessor = DoublePtrAccessor::create (pushData.inputValueAddress, pushData.outputValueAddress);
81
111
PtrAccessor scratch_accessor = PtrAccessor::create (pushData.scratchAddress);
112
// SharedAccessor has no members to initialise; it operates on the groupshared sdata array.
+ SharedAccessor shared_accessor;
82
113
counter.scatter (
83
114
key_accessor,
84
115
value_accessor,
85
116
scratch_accessor,
117
+ shared_accessor,
86
118
params
87
119
);
88
120
}
0 commit comments