1
- #include "nbl/builtin/hlsl/bda/__ptr .hlsl"
1
+ #include "nbl/builtin/hlsl/bda/bda_accessor .hlsl"
2
2
#include "nbl/builtin/hlsl/sort/counting.hlsl"
3
3
#include "app_resources/common.hlsl"
4
4
5
5
[[vk::push_constant]] CountingPushData pushData;
6
6
7
- struct PtrAccessor
8
- {
9
- static PtrAccessor create (const uint64_t addr)
10
- {
11
- PtrAccessor ptr;
12
- ptr.addr = addr;
13
- return ptr;
14
- }
15
-
16
- uint32_t get (const uint64_t index)
17
- {
18
- return nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
19
-
20
- deref ().load ();
21
- }
22
-
23
- void set (const uint64_t index, const uint32_t value)
24
- {
25
- nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
26
-
27
- deref ().store (value);
28
- }
29
-
30
- uint32_t atomicAdd (const uint64_t index, const uint32_t value)
31
- {
32
- nbl::hlsl::bda::__spv_ptr_t < uint32_t > ptr = nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
33
-
34
- deref ().get_ptr ();
35
-
36
- return nbl::hlsl::glsl::atomicAdd (ptr, value);
37
- }
38
-
39
- uint32_t atomicSub (const uint64_t index, const uint32_t value)
40
- {
41
- return atomicAdd (index, (uint32_t) (-1 * value));
42
- }
43
-
44
- uint64_t addr;
45
- };
7
+ using PtrAccessor = nbl::hlsl::bda::BdaAccessor < uint32_t >;
46
8
47
9
groupshared uint32_t sdata[BucketCount];
48
10
@@ -71,27 +33,25 @@ struct SharedAccessor
71
33
72
34
struct DoublePtrAccessor
73
35
{
74
- static DoublePtrAccessor create (const uint64_t in_addr , const uint64_t out_addr )
36
+ static DoublePtrAccessor create (const PtrAccessor input , const PtrAccessor output )
75
37
{
76
- DoublePtrAccessor ptr ;
77
- ptr.in_addr = in_addr ;
78
- ptr.out_addr = out_addr ;
79
- return ptr ;
38
+ DoublePtrAccessor accessor ;
39
+ accessor.input = input ;
40
+ accessor.output = output ;
41
+ return accessor ;
80
42
}
81
43
82
44
uint32_t get (const uint64_t index)
83
45
{
84
- return nbl::hlsl::bda::__ptr < uint32_t > (in_addr + sizeof (uint32_t) * index).template
85
- deref ().load ();
46
+ return input.get (index);
86
47
}
87
48
88
49
void set (const uint64_t index, const uint32_t value)
89
50
{
90
- nbl::hlsl::bda::__ptr < uint32_t > (out_addr + sizeof (uint32_t) * index).template
91
- deref ().store (value);
51
+ output.set (index, value);
92
52
}
93
53
94
- uint64_t in_addr, out_addr ;
54
+ PtrAccessor input, output ;
95
55
};
96
56
97
57
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize ()
@@ -109,8 +69,14 @@ void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
109
69
params.maximum = pushData.maximum;
110
70
111
71
nbl::hlsl::sort::counting <WorkgroupSize, BucketCount, uint32_t, DoublePtrAccessor, DoublePtrAccessor, PtrAccessor, SharedAccessor > counter;
112
- DoublePtrAccessor key_accessor = DoublePtrAccessor::create (pushData.inputKeyAddress, pushData.outputKeyAddress);
113
- DoublePtrAccessor value_accessor = DoublePtrAccessor::create (pushData.inputValueAddress, pushData.outputValueAddress);
72
+ DoublePtrAccessor key_accessor = DoublePtrAccessor::create (
73
+ PtrAccessor::create (pushData.inputKeyAddress),
74
+ PtrAccessor::create (pushData.outputKeyAddress)
75
+ );
76
+ DoublePtrAccessor value_accessor = DoublePtrAccessor::create (
77
+ PtrAccessor::create (pushData.inputValueAddress),
78
+ PtrAccessor::create (pushData.outputValueAddress)
79
+ );
114
80
PtrAccessor scratch_accessor = PtrAccessor::create (pushData.scratchAddress);
115
81
SharedAccessor shared_accessor;
116
82
counter.scatter (
0 commit comments