Skip to content

Commit 29cb3f2

Browse files
committed
Make use of BdaAccessor
1 parent 83181a9 commit 29cb3f2

File tree

2 files changed

+20
-89
lines changed

2 files changed

+20
-89
lines changed

CountingSort/app_resources/prefix_sum_shader.comp.hlsl

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,10 @@
1-
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
1+
#include "nbl/builtin/hlsl/bda/bda_accessor.hlsl"
22
#include "nbl/builtin/hlsl/sort/counting.hlsl"
33
#include "app_resources/common.hlsl"
44

55
[[vk::push_constant]] CountingPushData pushData;
66

7-
struct PtrAccessor
8-
{
9-
static PtrAccessor create(const uint64_t addr)
10-
{
11-
PtrAccessor ptr;
12-
ptr.addr = addr;
13-
return ptr;
14-
}
15-
16-
uint32_t get(const uint64_t index)
17-
{
18-
return nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
19-
deref().load();
20-
}
21-
22-
void set(const uint64_t index, const uint32_t value)
23-
{
24-
nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
25-
deref().store(value);
26-
}
27-
28-
uint32_t atomicAdd(const uint64_t index, const uint32_t value)
29-
{
30-
nbl::hlsl::bda::__spv_ptr_t < uint32_t > ptr = nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
31-
deref().get_ptr();
32-
33-
return nbl::hlsl::glsl::atomicAdd(ptr, value);
34-
}
35-
36-
uint32_t atomicSub(const uint64_t index, const uint32_t value)
37-
{
38-
return atomicAdd(index, (uint32_t) (-1 * value));
39-
}
40-
41-
uint64_t addr;
42-
};
7+
using PtrAccessor = nbl::hlsl::bda::BdaAccessor < uint32_t >;
438

449
groupshared uint32_t sdata[BucketCount];
4510

CountingSort/app_resources/scatter_shader.comp.hlsl

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,10 @@
1-
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
1+
#include "nbl/builtin/hlsl/bda/bda_accessor.hlsl"
22
#include "nbl/builtin/hlsl/sort/counting.hlsl"
33
#include "app_resources/common.hlsl"
44

55
[[vk::push_constant]] CountingPushData pushData;
66

7-
struct PtrAccessor
8-
{
9-
static PtrAccessor create(const uint64_t addr)
10-
{
11-
PtrAccessor ptr;
12-
ptr.addr = addr;
13-
return ptr;
14-
}
15-
16-
uint32_t get(const uint64_t index)
17-
{
18-
return nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
19-
20-
deref().load();
21-
}
22-
23-
void set(const uint64_t index, const uint32_t value)
24-
{
25-
nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
26-
27-
deref().store(value);
28-
}
29-
30-
uint32_t atomicAdd(const uint64_t index, const uint32_t value)
31-
{
32-
nbl::hlsl::bda::__spv_ptr_t < uint32_t > ptr = nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
33-
34-
deref().get_ptr();
35-
36-
return nbl::hlsl::glsl::atomicAdd(ptr, value);
37-
}
38-
39-
uint32_t atomicSub(const uint64_t index, const uint32_t value)
40-
{
41-
return atomicAdd(index, (uint32_t) (-1 * value));
42-
}
43-
44-
uint64_t addr;
45-
};
7+
using PtrAccessor = nbl::hlsl::bda::BdaAccessor < uint32_t >;
468

479
groupshared uint32_t sdata[BucketCount];
4810

@@ -71,27 +33,25 @@ struct SharedAccessor
7133

7234
struct DoublePtrAccessor
7335
{
74-
static DoublePtrAccessor create(const uint64_t in_addr, const uint64_t out_addr)
36+
static DoublePtrAccessor create(const PtrAccessor input, const PtrAccessor output)
7537
{
76-
DoublePtrAccessor ptr;
77-
ptr.in_addr = in_addr;
78-
ptr.out_addr = out_addr;
79-
return ptr;
38+
DoublePtrAccessor accessor;
39+
accessor.input = input;
40+
accessor.output = output;
41+
return accessor;
8042
}
8143

8244
uint32_t get(const uint64_t index)
8345
{
84-
return nbl::hlsl::bda::__ptr < uint32_t > (in_addr + sizeof(uint32_t) * index).template
85-
deref().load();
46+
return input.get(index);
8647
}
8748

8849
void set(const uint64_t index, const uint32_t value)
8950
{
90-
nbl::hlsl::bda::__ptr < uint32_t > (out_addr + sizeof(uint32_t) * index).template
91-
deref().store(value);
51+
output.set(index, value);
9252
}
9353

94-
uint64_t in_addr, out_addr;
54+
PtrAccessor input, output;
9555
};
9656

9757
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
@@ -109,8 +69,14 @@ void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
10969
params.maximum = pushData.maximum;
11070

11171
nbl::hlsl::sort::counting <WorkgroupSize, BucketCount, uint32_t, DoublePtrAccessor, DoublePtrAccessor, PtrAccessor, SharedAccessor > counter;
112-
DoublePtrAccessor key_accessor = DoublePtrAccessor::create(pushData.inputKeyAddress, pushData.outputKeyAddress);
113-
DoublePtrAccessor value_accessor = DoublePtrAccessor::create(pushData.inputValueAddress, pushData.outputValueAddress);
72+
DoublePtrAccessor key_accessor = DoublePtrAccessor::create(
73+
PtrAccessor::create(pushData.inputKeyAddress),
74+
PtrAccessor::create(pushData.outputKeyAddress)
75+
);
76+
DoublePtrAccessor value_accessor = DoublePtrAccessor::create(
77+
PtrAccessor::create(pushData.inputValueAddress),
78+
PtrAccessor::create(pushData.outputValueAddress)
79+
);
11480
PtrAccessor scratch_accessor = PtrAccessor::create(pushData.scratchAddress);
11581
SharedAccessor shared_accessor;
11682
counter.scatter(

0 commit comments

Comments
 (0)