@@ -9,7 +9,7 @@ namespace hlsl
9
9
namespace property_pools
10
10
{
11
11
12
// Push-constant block of this shader. The code visible in this diff reads three
// fields from it: beginOffset and endOffset (used as loop bounds in copyLoop)
// and transferCommandsAddress (used in main to locate the TransferRequest).
// This change only swaps the declared type; the variable name stays `globals`.
- [[vk::push_constant]] GlobalPushContants globals;
12
+ [[vk::push_constant]] TransferDispatchInfo globals;
13
13
14
14
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2, uint64_t DstIndexSizeLog2>
15
15
struct TransferLoop
@@ -39,12 +39,12 @@ struct TransferLoop
39
39
else if (SrcIndexSizeLog2 == 3 ) vk::RawBufferStore<uint64_t>(dstAddressMapped, vk::RawBufferLoad<uint64_t>(srcAddressMapped));
40
40
}
41
41
42
- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
42
+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
43
43
{
44
44
uint64_t elementCount = uint64_t (transferRequest.elementCount32)
45
45
| uint64_t (transferRequest.elementCountExtra) << 32 ;
46
- uint64_t lastInvocation = min (elementCount, globals .endOffset);
47
- for (uint64_t invocationIndex = globals .beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
46
+ uint64_t lastInvocation = min (elementCount, dispatchInfo .endOffset);
47
+ for (uint64_t invocationIndex = dispatchInfo .beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
48
48
{
49
49
iteration (propertyId, transferRequest, invocationIndex);
50
50
}
@@ -62,58 +62,53 @@ struct TransferLoop
62
62
// Last stage of the permutation chain: turns the runtime field
// transferRequest.dstIndexSizeLog2 into the fifth template argument of
// TransferLoop and forwards the call to TransferLoop::copyLoop.
//
// Values 0, 1 and 2 are matched explicitly; every other value falls through to
// the final `else` and is handled as 3.
//
// The new revision adds `dispatchInfo` as the first parameter and forwards it
// unchanged; all other arguments are passed through as before.
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
63
63
struct TransferLoopPermutationSrcIndexSizeLog
64
64
{
65
- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
65
+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
66
66
{
67
- if (transferRequest.dstIndexSizeLog2 == 0 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
68
- else if (transferRequest.dstIndexSizeLog2 == 1 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
69
- else if (transferRequest.dstIndexSizeLog2 == 2 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
70
- else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
67
+ if (transferRequest.dstIndexSizeLog2 == 0 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
68
+ else if (transferRequest.dstIndexSizeLog2 == 1 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
69
+ else if (transferRequest.dstIndexSizeLog2 == 2 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
70
+ else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
71
71
}
72
72
};
73
73
74
74
// Permutation stage reached once Fill, SrcIndexIota and DstIndexIota are fixed
// as template arguments. It turns the runtime field
// transferRequest.srcIndexSizeLog2 into the fourth template argument of
// TransferLoopPermutationSrcIndexSizeLog and forwards the call to it.
//
// Values 0, 1 and 2 are matched explicitly; every other value falls through to
// the final `else` and is handled as 3.
//
// The new revision adds `dispatchInfo` as the first parameter and forwards it
// unchanged.
template<bool Fill, bool SrcIndexIota, bool DstIndexIota>
75
75
struct TransferLoopPermutationDstIota
76
76
{
77
- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
77
+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
78
78
{
79
- if (transferRequest.srcIndexSizeLog2 == 0 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
80
- else if (transferRequest.srcIndexSizeLog2 == 1 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
81
- else if (transferRequest.srcIndexSizeLog2 == 2 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
82
- else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
79
+ if (transferRequest.srcIndexSizeLog2 == 0 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
80
+ else if (transferRequest.srcIndexSizeLog2 == 1 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
81
+ else if (transferRequest.srcIndexSizeLog2 == 2 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
82
+ else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
83
83
}
84
84
};
85
85
86
86
// Permutation stage reached once Fill and SrcIndexIota are fixed as template
// arguments. It derives the DstIndexIota template argument at runtime:
// transferRequest.dstIndexAddr == 0 selects `true`, anything else `false`.
// The call is then forwarded to TransferLoopPermutationDstIota.
//
// The new revision adds `dispatchInfo` as the first parameter and forwards it
// unchanged.
template<bool Fill, bool SrcIndexIota>
87
87
struct TransferLoopPermutationSrcIota
88
88
{
89
- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
89
+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
90
90
{
91
91
bool dstIota = transferRequest.dstIndexAddr == 0 ;
92
- if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
93
- else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
92
+ if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
93
+ else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
94
94
}
95
95
};
96
96
97
97
// First stage of the permutation chain, entered from main with Fill already
// chosen. It derives the SrcIndexIota template argument at runtime:
// transferRequest.srcIndexAddr == 0 selects `true`, anything else `false`.
// The call is then forwarded to TransferLoopPermutationSrcIota.
//
// The new revision adds `dispatchInfo` as the first parameter and forwards it
// unchanged.
template<bool Fill>
98
98
struct TransferLoopPermutationFill
99
99
{
100
- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
100
+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
101
101
{
102
102
bool srcIota = transferRequest.srcIndexAddr == 0 ;
103
- if (srcIota) { TransferLoopPermutationSrcIota<Fill, true > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
104
- else { TransferLoopPermutationSrcIota<Fill, false > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
103
+ if (srcIota) { TransferLoopPermutationSrcIota<Fill, true > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
104
+ else { TransferLoopPermutationSrcIota<Fill, false > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
105
105
}
106
106
};
107
107
108
- template<typename device_capabilities>
109
- void main (uint32_t3 dispatchId)
110
- {
111
- const uint propertyId = dispatchId.y;
112
- const uint invocationIndex = dispatchId.x;
113
-
114
- // Loading transfer request from the pointer (can't use struct
115
- // with BDA on HLSL SPIRV)
116
- uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof (TransferRequest) * propertyId;
108
+ // Loading transfer request from the pointer (can't use struct
109
+ // with BDA on HLSL SPIRV)
110
+ static TransferRequest TransferRequest::newFromAddress (const uint64_t transferCmdAddr)
111
+ {
117
112
TransferRequest transferRequest;
118
113
transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr,8 );
119
114
transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr + sizeof (uint64_t),8 );
@@ -129,35 +124,31 @@ void main(uint32_t3 dispatchId)
129
124
transferRequest.srcIndexSizeLog2 = uint32_t (bitfieldType >> (32 + 3 + 24 + 1 ));
130
125
transferRequest.dstIndexSizeLog2 = uint32_t (bitfieldType >> (32 + 3 + 24 + 1 + 2 ));
131
126
132
- const uint dispatchSize = nbl::hlsl::device_capabilities_traits<device_capabilities>::maxOptimallyResidentWorkgroupInvocations;
127
+ return transferRequest;
128
+ }
129
+
130
// Per-invocation body of the transfer shader (new revision).
//
// dispatchId.y selects which TransferRequest to process (propertyId) and
// dispatchId.x is passed to copyLoop as the base invocation index. The request
// is read from globals.transferCommandsAddress + sizeof(TransferRequest) *
// propertyId via TransferRequest::newFromAddress. The call then enters the
// permutation chain at TransferLoopPermutationFill<true> when
// transferRequest.fill == 1, and at TransferLoopPermutationFill<false> otherwise.
//
// Changes in this revision:
//  - dispatchSize is now a parameter supplied by the caller instead of being
//    computed inside this function.
//  - the push-constant block `globals` is passed explicitly as the dispatchInfo
//    argument of copyLoop.
//  - the commented-out debug stores are removed.
//
// NOTE(review): the new lines of this function never reference the
// device_capabilities template parameter — confirm it is still needed.
+ template<typename device_capabilities>
131
+ void main (uint32_t3 dispatchId, const uint dispatchSize)
132
+ {
133
+ const uint propertyId = dispatchId.y;
134
+ const uint invocationIndex = dispatchId.x;
135
+
136
+ uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof (TransferRequest) * propertyId;
137
+ TransferRequest transferRequest = TransferRequest::newFromAddress (transferCmdAddr);
138
+
133
139
const bool fill = transferRequest.fill == 1 ;
134
140
135
- //uint64_t debugWriteAddr = transferRequest.dstAddr + sizeof(uint64_t) * 9 * propertyId;
136
- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 0, transferRequest.srcAddr,8);
137
- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 1, transferRequest.dstAddr,8);
138
- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 2, transferRequest.srcIndexAddr,8);
139
- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 3, transferRequest.dstIndexAddr,8);
140
- //uint64_t elementCount = uint64_t(transferRequest.elementCount32)
141
- // | uint64_t(transferRequest.elementCountExtra) << 32;
142
- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 4, elementCount,8);
143
- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 5, transferRequest.propertySize,4);
144
- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 6, transferRequest.fill,4);
145
- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 7, transferRequest.srcIndexSizeLog2,4);
146
- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 8, transferRequest.dstIndexSizeLog2,4);
147
- //vk::RawBufferStore<uint64_t>(transferRequest.dstAddr + sizeof(uint64_t) * invocationIndex, invocationIndex,8);
148
-
149
- if (fill) { TransferLoopPermutationFill<true > loop; loop.copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize); }
150
- else { TransferLoopPermutationFill<false > loop; loop.copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize); }
141
+ if (fill) { TransferLoopPermutationFill<true > loop; loop.copyLoop (globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
142
+ else { TransferLoopPermutationFill<false > loop; loop.copyLoop (globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
151
143
}
152
144
153
145
}
154
146
}
155
147
}
156
148
157
- // TODO: instead use some sort of replace function for getting optimal size?
158
// Shader entry point. Forwards SV_DispatchThreadID to
// nbl::hlsl::property_pools::main.
//
// The new revision replaces the literal workgroup X size 512 with
// nbl::hlsl::property_pools::OptimalDispatchSize and passes that same constant
// on as dispatchSize, so the workgroup size and the stride used by copyLoop
// come from one value.
//
// NOTE(review): copyLoop starts at dispatchId.x and advances by dispatchSize.
// Elements are covered without overlap only while dispatchId.x stays below
// dispatchSize — presumably a single workgroup is launched along X; confirm
// against the host-side dispatch.
- [numthreads (512 ,1 ,1 )]
149
+ [numthreads (nbl::hlsl::property_pools::OptimalDispatchSize,1 ,1 )]
159
150
void main (uint32_t3 dispatchId : SV_DispatchThreadID )
160
151
{
161
- nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId);
152
+ nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId, nbl::hlsl::property_pools::OptimalDispatchSize );
162
153
}
163
154
0 commit comments