Skip to content

Commit b00e75c

Browse files
committed
Move the globals into userspace code
1 parent f7d0cc3 commit b00e75c

File tree

2 files changed

+74
-13
lines changed

2 files changed

+74
-13
lines changed

CountingSort/app_resources/prefix_sum_shader.comp.hlsl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@
66

77
struct PtrAccessor
88
{
9-
static PtrAccessor create(uint64_t addr)
9+
static PtrAccessor create(const uint64_t addr)
1010
{
1111
PtrAccessor ptr;
1212
ptr.addr = addr;
1313
return ptr;
1414
}
1515

16-
uint32_t get(uint64_t index)
16+
uint32_t get(const uint64_t index)
1717
{
1818
return nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
1919
deref().load();
2020
}
2121

22-
void set(uint64_t index, uint32_t value)
22+
void set(const uint64_t index, const uint32_t value)
2323
{
2424
nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
2525
deref().store(value);
2626
}
2727

28-
uint32_t atomicAdd(uint64_t index, uint32_t value)
28+
uint32_t atomicAdd(const uint64_t index, const uint32_t value)
2929
{
3030
nbl::hlsl::bda::__spv_ptr_t < uint32_t > ptr = nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
3131
deref().get_ptr();
@@ -36,6 +36,31 @@ struct PtrAccessor
3636
uint64_t addr;
3737
};
3838

39+
// Workgroup-shared (groupshared) array of BucketCount uint32_t slots.
// BucketCount is declared outside this view.
groupshared uint32_t sdata[BucketCount];
40+
41+
// Accessor over the groupshared `sdata` array. In this commit it is supplied as the
// last template argument of nbl::hlsl::sort::counting and an instance is passed to
// counter.histogram() in main(). None of the methods range-check `index`.
struct SharedAccessor
42+
{
43+
// Returns the value currently stored in sdata[index].
uint32_t get(const uint32_t index)
44+
{
45+
return sdata[index];
46+
}
47+
48+
// Overwrites sdata[index] with `value` (plain, non-atomic store).
void set(const uint32_t index, const uint32_t value)
49+
{
50+
sdata[index] = value;
51+
}
52+
53+
// Forwards to nbl::hlsl::glsl::atomicAdd on sdata[index] and returns its result.
// NOTE(review): presumably the pre-add value, as with GLSL atomicAdd - confirm against the Nabla header.
uint32_t atomicAdd(const uint32_t index, const uint32_t value)
54+
{
55+
return nbl::hlsl::glsl::atomicAdd(sdata[index], value);
56+
}
57+
58+
// Calls nbl::hlsl::glsl::barrier().
// NOTE(review): the method name implies a workgroup execution + memory barrier; confirm that barrier() provides both.
void workgroupExecutionAndMemoryBarrier()
59+
{
60+
nbl::hlsl::glsl::barrier();
61+
}
62+
};
63+
3964
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
4065
{
4166
return uint32_t3(WorkgroupSize, 1, 1);
@@ -45,17 +70,21 @@ uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
4570
// Compute entry point of prefix_sum_shader: fills a CountingParameters<uint32_t>, builds the
// accessors and calls counter.histogram(). ID and GroupID are not referenced in this body.
// This is a diff view: the line under `53-` is the removed declaration of `counter`,
// the line under `80+` is its replacement.
void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
4671
{
4772
nbl::hlsl::sort::CountingParameters < uint32_t > params;
73+
// Added by this commit: the sizes are also handed over through the params struct.
params.workgroupSize = WorkgroupSize;
74+
params.bucketCount = BucketCount;
4875
// Remaining fields are copied verbatim from pushData (declared outside this view).
params.dataElementCount = pushData.dataElementCount;
4976
params.elementsPerWT = pushData.elementsPerWT;
5077
params.minimum = pushData.minimum;
5178
params.maximum = pushData.maximum;
5279

53-
nbl::hlsl::sort::counting <uint32_t, PtrAccessor, PtrAccessor, PtrAccessor> counter;
80+
// New form: WorkgroupSize and BucketCount become template arguments and SharedAccessor is appended as an extra accessor type.
nbl::hlsl::sort::counting <WorkgroupSize, BucketCount, uint32_t, PtrAccessor, PtrAccessor, PtrAccessor, SharedAccessor> counter;
5481
// Both PtrAccessors wrap 64-bit addresses taken from pushData.
PtrAccessor input_accessor = PtrAccessor::create(pushData.inputKeyAddress);
5582
PtrAccessor scratch_accessor = PtrAccessor::create(pushData.scratchAddress);
83+
// SharedAccessor has no data members; it operates on the file-level groupshared sdata array.
SharedAccessor shared_accessor;
5684
counter.histogram(
5785
input_accessor,
5886
scratch_accessor,
87+
shared_accessor,
5988
params
6089
);
6190
}

CountingSort/app_resources/scatter_shader.comp.hlsl

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,31 @@
66

77
struct PtrAccessor
88
{
9-
static PtrAccessor create(uint64_t addr)
9+
static PtrAccessor create(const uint64_t addr)
1010
{
1111
PtrAccessor ptr;
1212
ptr.addr = addr;
1313
return ptr;
1414
}
1515

16-
uint32_t get(uint64_t index)
16+
uint32_t get(const uint64_t index)
1717
{
1818
return nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
19+
1920
deref().load();
2021
}
2122

22-
void set(uint64_t index, uint32_t value)
23+
void set(const uint64_t index, const uint32_t value)
2324
{
2425
nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
26+
2527
deref().store(value);
2628
}
2729

28-
uint32_t atomicAdd(uint64_t index, uint32_t value)
30+
uint32_t atomicAdd(const uint64_t index, const uint32_t value)
2931
{
3032
nbl::hlsl::bda::__spv_ptr_t < uint32_t > ptr = nbl::hlsl::bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
33+
3134
deref().get_ptr();
3235

3336
return nbl::hlsl::glsl::atomicAdd(ptr, value);
@@ -36,23 +39,48 @@ struct PtrAccessor
3639
uint64_t addr;
3740
};
3841

42+
// Workgroup-shared (groupshared) array of BucketCount uint32_t slots.
// BucketCount is declared outside this view.
groupshared uint32_t sdata[BucketCount];
43+
44+
// Accessor over the groupshared `sdata` array. In this commit it is supplied as the
// last template argument of nbl::hlsl::sort::counting and an instance is passed to
// counter.scatter() in main(). None of the methods range-check `index`.
struct SharedAccessor
45+
{
46+
// Returns the value currently stored in sdata[index].
uint32_t get(const uint32_t index)
47+
{
48+
return sdata[index];
49+
}
50+
51+
// Overwrites sdata[index] with `value` (plain, non-atomic store).
void set(const uint32_t index, const uint32_t value)
52+
{
53+
sdata[index] = value;
54+
}
55+
56+
// Forwards to nbl::hlsl::glsl::atomicAdd on sdata[index] and returns its result.
// NOTE(review): presumably the pre-add value, as with GLSL atomicAdd - confirm against the Nabla header.
uint32_t atomicAdd(const uint32_t index, const uint32_t value)
57+
{
58+
return nbl::hlsl::glsl::atomicAdd(sdata[index], value);
59+
}
60+
61+
// Calls nbl::hlsl::glsl::barrier().
// NOTE(review): the method name implies a workgroup execution + memory barrier; confirm that barrier() provides both.
void workgroupExecutionAndMemoryBarrier()
62+
{
63+
nbl::hlsl::glsl::barrier();
64+
}
65+
};
66+
3967
struct DoublePtrAccessor
4068
{
41-
static DoublePtrAccessor create(uint64_t in_addr, uint64_t out_addr)
69+
static DoublePtrAccessor create(const uint64_t in_addr, const uint64_t out_addr)
4270
{
4371
DoublePtrAccessor ptr;
4472
ptr.in_addr = in_addr;
4573
ptr.out_addr = out_addr;
4674
return ptr;
4775
}
4876

49-
uint32_t get(uint64_t index)
77+
uint32_t get(const uint64_t index)
5078
{
5179
return nbl::hlsl::bda::__ptr < uint32_t > (in_addr + sizeof(uint32_t) * index).template
5280
deref().load();
5381
}
5482

55-
void set(uint64_t index, uint32_t value)
83+
void set(const uint64_t index, const uint32_t value)
5684
{
5785
nbl::hlsl::bda::__ptr < uint32_t > (out_addr + sizeof(uint32_t) * index).template
5886
deref().store(value);
@@ -70,19 +98,23 @@ uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
7098
// Compute entry point of scatter_shader: fills a CountingParameters<uint32_t>, builds the
// key/value/scratch/shared accessors and calls counter.scatter(). ID and GroupID are not
// referenced in this body.
// This is a diff view: the line under `78-` is the removed declaration of `counter`,
// the line under `108+` is its replacement.
void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
7199
{
72100
nbl::hlsl::sort::CountingParameters < uint32_t > params;
101+
// Added by this commit: the sizes are also handed over through the params struct.
params.workgroupSize = WorkgroupSize;
102+
params.bucketCount = BucketCount;
73103
// Remaining fields are copied verbatim from pushData (declared outside this view).
params.dataElementCount = pushData.dataElementCount;
74104
params.elementsPerWT = pushData.elementsPerWT;
75105
params.minimum = pushData.minimum;
76106
params.maximum = pushData.maximum;
77107

78-
nbl::hlsl::sort::counting <uint32_t, DoublePtrAccessor, DoublePtrAccessor, PtrAccessor> counter;
108+
// New form: WorkgroupSize and BucketCount become template arguments and SharedAccessor is appended as an extra accessor type.
nbl::hlsl::sort::counting <WorkgroupSize, BucketCount, uint32_t, DoublePtrAccessor, DoublePtrAccessor, PtrAccessor, SharedAccessor > counter;
79109
// Each DoublePtrAccessor pairs an input address (read by get) with an output address (written by set).
DoublePtrAccessor key_accessor = DoublePtrAccessor::create(pushData.inputKeyAddress, pushData.outputKeyAddress);
80110
DoublePtrAccessor value_accessor = DoublePtrAccessor::create(pushData.inputValueAddress, pushData.outputValueAddress);
81111
PtrAccessor scratch_accessor = PtrAccessor::create(pushData.scratchAddress);
112+
// SharedAccessor has no data members; it operates on the file-level groupshared sdata array.
SharedAccessor shared_accessor;
82113
counter.scatter(
83114
key_accessor,
84115
value_accessor,
85116
scratch_accessor,
117+
shared_accessor,
86118
params
87119
);
88120
}

0 commit comments

Comments
 (0)