
Commit a1747c6

Work on HLSL impl of property pools
1 parent 64cbb65 commit a1747c6

2 files changed: 163 additions, 0 deletions

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
#include "nbl/builtin/hlsl/jit/device_capabilities.hlsl"
#include "nbl/builtin/hlsl/property_pool/transfer.hlsl"

namespace nbl
{
namespace hlsl
{
namespace property_pools
{
// https://github.com/microsoft/DirectXShaderCompiler/issues/6144
template<typename capability_traits=nbl::hlsl::jit::device_capabilities_traits>
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize() {
    return uint32_t3(capability_traits::maxOptimallyResidentWorkgroupInvocations, 1, 1);
}
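// Illustrative sketch only, assuming the JIT capability traits expose
// maxOptimallyResidentWorkgroupInvocations as above: downstream code could derive the
// per-dispatch stride from the same source, e.g.
//   const uint32_t dispatchSize = nbl::hlsl::glsl::gl_WorkGroupSize().x;
// which matches the value main() reads from the traits further below.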

[[vk::push_constant]] GlobalPushConstants globals;

template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2, uint64_t DstIndexSizeLog2>
struct TransferLoop
{
    void iteration(uint propertyId, uint64_t propertySize, uint64_t srcAddr, uint64_t dstAddr, uint invocationIndex)
    {
        const uint64_t srcOffset = uint64_t(invocationIndex) * (uint64_t(1) << SrcIndexSizeLog2) * propertySize;
        const uint64_t dstOffset = uint64_t(invocationIndex) * (uint64_t(1) << DstIndexSizeLog2) * propertySize;

        const uint64_t srcIndexAddress = Fill ? srcAddr + srcOffset : srcAddr;
        const uint64_t dstIndexAddress = Fill ? dstAddr + dstOffset : dstAddr;

        const uint64_t srcAddressMapped = SrcIndexIota ? srcIndexAddress : vk::RawBufferLoad<uint64_t>(srcIndexAddress);
        const uint64_t dstAddressMapped = DstIndexIota ? dstIndexAddress : vk::RawBufferLoad<uint64_t>(dstIndexAddress);

        if (SrcIndexSizeLog2 == 0) {} // we can't write individual bytes
        else if (SrcIndexSizeLog2 == 1) vk::RawBufferStore<uint16_t>(dstAddressMapped, vk::RawBufferLoad<uint16_t>(srcAddressMapped));
        else if (SrcIndexSizeLog2 == 2) vk::RawBufferStore<uint32_t>(dstAddressMapped, vk::RawBufferLoad<uint32_t>(srcAddressMapped));
        else if (SrcIndexSizeLog2 == 3) vk::RawBufferStore<uint64_t>(dstAddressMapped, vk::RawBufferLoad<uint64_t>(srcAddressMapped));
    }

    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        uint lastInvocation = min(transferRequest.elementCount, globals.endOffset);
        for (uint invocationIndex = globals.beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
        {
            iteration(propertyId, transferRequest.propertySize, transferRequest.srcAddr, transferRequest.dstAddr, invocationIndex);
        }
    }
};

// For creating permutations of the functions based on parameters that are constant over the transfer request
// These branches should all be scalar, and because of how templates work, the loops shouldn't have any
// branching within them
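// Illustrative example of how the chain below resolves (a sketch, not from the original
// source): for a non-fill request whose source indices are iota, whose destination
// indices are not, and whose index sizes are both 4 bytes (log2 == 2), the runtime
// branches collapse at compile time as
//   TransferLoopPermutationFill<false>
//     -> TransferLoopPermutationSrcIota<false, true>
//     -> TransferLoopPermutationDstIota<false, true, false>
//     -> TransferLoopPermutationSrcIndexSizeLog<false, true, false, 2>
//     -> TransferLoop<false, true, false, 2, 2>.copyLoop(...)
// so the inner per-invocation loop contains no data-dependent branches.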

template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
struct TransferLoopPermutationSrcIndexSizeLog
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        if (transferRequest.dstIndexSizeLog2 == 0) TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else if (transferRequest.dstIndexSizeLog2 == 1) TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else if (transferRequest.dstIndexSizeLog2 == 2) TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
    }
};

template<bool Fill, bool SrcIndexIota, bool DstIndexIota>
struct TransferLoopPermutationDstIota
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        if (transferRequest.srcIndexSizeLog2 == 0) TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else if (transferRequest.srcIndexSizeLog2 == 1) TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else if (transferRequest.srcIndexSizeLog2 == 2) TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
    }
};

template<bool Fill, bool SrcIndexIota>
struct TransferLoopPermutationSrcIota
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        bool dstIota = transferRequest.dstAddr == 0;
        if (dstIota) TransferLoopPermutationDstIota<Fill, SrcIndexIota, true>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else TransferLoopPermutationDstIota<Fill, SrcIndexIota, false>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
    }
};

template<bool Fill>
struct TransferLoopPermutationFill
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        bool srcIota = transferRequest.srcAddr == 0;
        if (srcIota) TransferLoopPermutationSrcIota<Fill, true>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        else TransferLoopPermutationSrcIota<Fill, false>.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
    }
};

void main(uint32_t3 dispatchId : SV_DispatchThreadID)
{
    const uint propertyId = dispatchId.y;
    const uint invocationIndex = dispatchId.x;

    // Loading the transfer request from the pointer (we can't use the struct
    // directly with BDA on HLSL SPIR-V)
    TransferRequest transferRequest;
    transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress);
    transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t));
    transferRequest.srcIndexAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 2);
    transferRequest.dstIndexAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 3);
    // TODO: These are all part of the same bitfield and should be read with a single RawBufferLoad
    transferRequest.elementCount = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 4);
    transferRequest.propertySize = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 5);
    transferRequest.fill = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 6);
    transferRequest.srcIndexSizeLog2 = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 7);
    transferRequest.dstIndexSizeLog2 = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 8);
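    // Hedged sketch of the TODO above, assuming the host packs the five bitfield members
    // into one uint64_t word, LSB-first in declaration order
    // (elementCount:35, propertySize:24, fill:1, srcIndexSizeLog2:2, dstIndexSizeLog2:2):
    //   const uint64_t packed = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 4);
    //   transferRequest.elementCount     = packed & ((uint64_t(1) << 35) - 1);
    //   transferRequest.propertySize     = (packed >> 35) & ((uint64_t(1) << 24) - 1);
    //   transferRequest.fill             = (packed >> 59) & uint64_t(1);
    //   transferRequest.srcIndexSizeLog2 = (packed >> 60) & uint64_t(3);
    //   transferRequest.dstIndexSizeLog2 = (packed >> 62) & uint64_t(3);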
116+
117+
const uint dispatchSize = capability_traits::maxOptimallyResidentWorkgroupInvocations;
118+
const bool fill = transferRequest.fill == 1;
119+
120+
if (fill) TransferLoopPermutationFill<true>.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize);
121+
else TransferLoopPermutationFill<false>.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize);
122+
}
123+
124+
}
125+
}
126+
}
127+
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
namespace nbl
{
namespace hlsl
{
namespace property_pools
{

struct TransferRequest
{
    // This represents a transfer command/request
    uint64_t srcAddr;
    uint64_t dstAddr;
    uint64_t srcIndexAddr = 0; // IOTA default
    uint64_t dstIndexAddr = 0; // IOTA default
    uint64_t elementCount : 35; // allow up to 64GB IGPUBuffers
    uint64_t propertySize : 24; // all the leftover bits (just use bytes now)
    uint64_t fill : 1 = false;
    // 0=uint8, 1=uint16, 2=uint32, 3=uint64
    uint64_t srcIndexSizeLog2 : 2 = 1;
    uint64_t dstIndexSizeLog2 : 2 = 1;
};
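// Worked bit budget for the packed members above:
// 35 (elementCount) + 24 (propertySize) + 1 (fill) + 2 (srcIndexSizeLog2) + 2 (dstIndexSizeLog2) = 64 bits,
// so all five fields fit exactly into a single uint64_t word following the four 8-byte addresses.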
22+
23+
struct GlobalPushContants
24+
{
25+
// Define the range of invocations (X axis) that will be transfered over in this dispatch
26+
// May be sectioned off in the case of overflow or any other situation that doesn't allow
27+
// for a full transfer
28+
uint64_t beginOffset;
29+
uint64_t endOffset;
30+
// BDA address (GPU pointer) into the transfer commands buffer
31+
uint64_t transferCommandsAddress;
32+
};
33+
34+
}
35+
}
36+
}
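The TransferRequest struct above documents its defaults (IOTA index addresses, 2-byte indices) with default member initializers; if those aren't available on the target HLSL dialect, a small helper can express the same defaults when building a request. The sketch below is only illustrative, assumes the field names above, and the helper name createDefaultTransferRequest is hypothetical:

TransferRequest createDefaultTransferRequest(uint64_t srcAddr, uint64_t dstAddr, uint64_t elementCount, uint64_t propertySize)
{
    // Illustrative sketch: fills in the defaults documented by the struct comments
    TransferRequest request;
    request.srcAddr = srcAddr;
    request.dstAddr = dstAddr;
    request.srcIndexAddr = 0;     // IOTA default
    request.dstIndexAddr = 0;     // IOTA default
    request.elementCount = elementCount;
    request.propertySize = propertySize;
    request.fill = 0;             // plain copy rather than fill
    request.srcIndexSizeLog2 = 1; // 16-bit indices by default
    request.dstIndexSizeLog2 = 1; // 16-bit indices by default
    return request;
}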
