Skip to content

Commit 99d80a7

Browse files
committed
PR review fixes
1 parent 1a0c998 commit 99d80a7

File tree

5 files changed

+57
-61
lines changed

5 files changed

+57
-61
lines changed

include/nbl/builtin/hlsl/property_pool/copy.comp.hlsl

Lines changed: 40 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ namespace hlsl
99
namespace property_pools
1010
{
1111

12-
[[vk::push_constant]] GlobalPushContants globals;
12+
[[vk::push_constant]] TransferDispatchInfo globals;
1313

1414
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2, uint64_t DstIndexSizeLog2>
1515
struct TransferLoop
@@ -39,12 +39,12 @@ struct TransferLoop
3939
else if (SrcIndexSizeLog2 == 3) vk::RawBufferStore<uint64_t>(dstAddressMapped, vk::RawBufferLoad<uint64_t>(srcAddressMapped));
4040
}
4141

42-
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
42+
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
4343
{
4444
uint64_t elementCount = uint64_t(transferRequest.elementCount32)
4545
| uint64_t(transferRequest.elementCountExtra) << 32;
46-
uint64_t lastInvocation = min(elementCount, globals.endOffset);
47-
for (uint64_t invocationIndex = globals.beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
46+
uint64_t lastInvocation = min(elementCount, dispatchInfo.endOffset);
47+
for (uint64_t invocationIndex = dispatchInfo.beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
4848
{
4949
iteration(propertyId, transferRequest, invocationIndex);
5050
}
@@ -62,58 +62,53 @@ struct TransferLoop
6262
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
6363
struct TransferLoopPermutationSrcIndexSizeLog
6464
{
65-
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
65+
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
6666
{
67-
if (transferRequest.dstIndexSizeLog2 == 0) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
68-
else if (transferRequest.dstIndexSizeLog2 == 1) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
69-
else if (transferRequest.dstIndexSizeLog2 == 2) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
70-
else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
67+
if (transferRequest.dstIndexSizeLog2 == 0) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
68+
else if (transferRequest.dstIndexSizeLog2 == 1) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
69+
else if (transferRequest.dstIndexSizeLog2 == 2) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
70+
else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
7171
}
7272
};
7373

7474
template<bool Fill, bool SrcIndexIota, bool DstIndexIota>
7575
struct TransferLoopPermutationDstIota
7676
{
77-
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
77+
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
7878
{
79-
if (transferRequest.srcIndexSizeLog2 == 0) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
80-
else if (transferRequest.srcIndexSizeLog2 == 1) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
81-
else if (transferRequest.srcIndexSizeLog2 == 2) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
82-
else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
79+
if (transferRequest.srcIndexSizeLog2 == 0) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
80+
else if (transferRequest.srcIndexSizeLog2 == 1) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
81+
else if (transferRequest.srcIndexSizeLog2 == 2) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
82+
else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
8383
}
8484
};
8585

8686
template<bool Fill, bool SrcIndexIota>
8787
struct TransferLoopPermutationSrcIota
8888
{
89-
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
89+
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
9090
{
9191
bool dstIota = transferRequest.dstIndexAddr == 0;
92-
if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
93-
else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
92+
if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
93+
else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
9494
}
9595
};
9696

9797
template<bool Fill>
9898
struct TransferLoopPermutationFill
9999
{
100-
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
100+
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
101101
{
102102
bool srcIota = transferRequest.srcIndexAddr == 0;
103-
if (srcIota) { TransferLoopPermutationSrcIota<Fill, true> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
104-
else { TransferLoopPermutationSrcIota<Fill, false> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
103+
if (srcIota) { TransferLoopPermutationSrcIota<Fill, true> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
104+
else { TransferLoopPermutationSrcIota<Fill, false> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
105105
}
106106
};
107107

108-
template<typename device_capabilities>
109-
void main(uint32_t3 dispatchId)
110-
{
111-
const uint propertyId = dispatchId.y;
112-
const uint invocationIndex = dispatchId.x;
113-
114-
// Loading transfer request from the pointer (can't use struct
115-
// with BDA on HLSL SPIRV)
116-
uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof(TransferRequest) * propertyId;
108+
// Loading transfer request from the pointer (can't use struct
109+
// with BDA on HLSL SPIRV)
110+
static TransferRequest TransferRequest::newFromAddress(const uint64_t transferCmdAddr)
111+
{
117112
TransferRequest transferRequest;
118113
transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr,8);
119114
transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr + sizeof(uint64_t),8);
@@ -129,35 +124,31 @@ void main(uint32_t3 dispatchId)
129124
transferRequest.srcIndexSizeLog2 = uint32_t(bitfieldType >> (32 + 3 + 24 + 1));
130125
transferRequest.dstIndexSizeLog2 = uint32_t(bitfieldType >> (32 + 3 + 24 + 1 + 2));
131126

132-
const uint dispatchSize = nbl::hlsl::device_capabilities_traits<device_capabilities>::maxOptimallyResidentWorkgroupInvocations;
127+
return transferRequest;
128+
}
129+
130+
template<typename device_capabilities>
131+
void main(uint32_t3 dispatchId, const uint dispatchSize)
132+
{
133+
const uint propertyId = dispatchId.y;
134+
const uint invocationIndex = dispatchId.x;
135+
136+
uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof(TransferRequest) * propertyId;
137+
TransferRequest transferRequest = TransferRequest::newFromAddress(transferCmdAddr);
138+
133139
const bool fill = transferRequest.fill == 1;
134140

135-
//uint64_t debugWriteAddr = transferRequest.dstAddr + sizeof(uint64_t) * 9 * propertyId;
136-
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 0, transferRequest.srcAddr,8);
137-
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 1, transferRequest.dstAddr,8);
138-
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 2, transferRequest.srcIndexAddr,8);
139-
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 3, transferRequest.dstIndexAddr,8);
140-
//uint64_t elementCount = uint64_t(transferRequest.elementCount32)
141-
// | uint64_t(transferRequest.elementCountExtra) << 32;
142-
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 4, elementCount,8);
143-
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 5, transferRequest.propertySize,4);
144-
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 6, transferRequest.fill,4);
145-
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 7, transferRequest.srcIndexSizeLog2,4);
146-
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 8, transferRequest.dstIndexSizeLog2,4);
147-
//vk::RawBufferStore<uint64_t>(transferRequest.dstAddr + sizeof(uint64_t) * invocationIndex, invocationIndex,8);
148-
149-
if (fill) { TransferLoopPermutationFill<true> loop; loop.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize); }
150-
else { TransferLoopPermutationFill<false> loop; loop.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize); }
141+
if (fill) { TransferLoopPermutationFill<true> loop; loop.copyLoop(globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
142+
else { TransferLoopPermutationFill<false> loop; loop.copyLoop(globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
151143
}
152144

153145
}
154146
}
155147
}
156148

157-
// TODO: instead use some sort of replace function for getting optimal size?
158-
[numthreads(512,1,1)]
149+
[numthreads(nbl::hlsl::property_pools::OptimalDispatchSize,1,1)]
159150
void main(uint32_t3 dispatchId : SV_DispatchThreadID)
160151
{
161-
nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId);
152+
nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId, nbl::hlsl::property_pools::OptimalDispatchSize);
162153
}
163154

include/nbl/builtin/hlsl/property_pool/transfer.hlsl

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,12 @@ struct TransferRequest
3434
uint32_t fill: 1;
3535
uint32_t srcIndexSizeLog2 : 2;
3636
uint32_t dstIndexSizeLog2 : 2;
37+
38+
// Reads a TransferRequest from a BDA
39+
static TransferRequest newFromAddress(const uint64_t address);
3740
};
3841

39-
struct GlobalPushContants
42+
struct TransferDispatchInfo
4043
{
4144
// BDA address (GPU pointer) into the transfer commands buffer
4245
uint64_t transferCommandsAddress;
@@ -49,6 +52,9 @@ struct GlobalPushContants
4952

5053
NBL_CONSTEXPR uint32_t MaxPropertiesPerDispatch = 128;
5154

55+
// TODO: instead use some sort of replace function for getting optimal size?
56+
NBL_CONSTEXPR uint32_t OptimalDispatchSize = 256;
57+
5258
}
5359
}
5460
}

include/nbl/video/utilities/CPropertyPoolHandler.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,7 @@
1313
#include "nbl/video/utilities/IPropertyPool.h"
1414

1515
#include "glm/glm/glm.hpp"
16-
#include <nbl/builtin/hlsl/cpp_compat/matrix.hlsl>
17-
#include <nbl/builtin/hlsl/cpp_compat/vector.hlsl>
16+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
1817
#include "nbl/builtin/hlsl/property_pool/transfer.hlsl"
1918

2019
namespace nbl::video

include/nbl/video/utilities/IPropertyPool.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@
1212
#include "nbl/video/IGPUDescriptorSetLayout.h"
1313

1414
#include "glm/glm/glm.hpp"
15-
#include <nbl/builtin/hlsl/cpp_compat/matrix.hlsl>
16-
#include <nbl/builtin/hlsl/cpp_compat/vector.hlsl>
15+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
1716
#include "nbl/builtin/hlsl/property_pool/transfer.hlsl"
1817

1918
namespace nbl::video

src/nbl/video/utilities/CPropertyPoolHandler.cpp

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@ CPropertyPoolHandler::CPropertyPoolHandler(core::smart_refctd_ptr<ILogicalDevice
2424
return shader;
2525
};
2626
auto shader = loadShader("../../../include/nbl/builtin/hlsl/property_pool/copy.comp.hlsl");
27-
const asset::SPushConstantRange baseDWORD = { asset::IShader::ESS_COMPUTE,0u,sizeof(nbl::hlsl::property_pools::GlobalPushContants) };
28-
auto layout = m_device->createPipelineLayout({ &baseDWORD,1u });
27+
const asset::SPushConstantRange transferInfoPushConstants = { asset::IShader::ESS_COMPUTE,0u,sizeof(nbl::hlsl::property_pools::TransferDispatchInfo) };
28+
auto layout = m_device->createPipelineLayout({ &transferInfoPushConstants,1u });
2929

3030
{
3131
video::IGPUComputePipeline::SCreationParams params = {};
@@ -93,7 +93,7 @@ bool CPropertyPoolHandler::transferProperties(
9393
IGPUCommandBuffer* const cmdbuf, //IGPUFence* const fence,
9494
const asset::SBufferBinding<video::IGPUBuffer>& scratch, const asset::SBufferBinding<video::IGPUBuffer>& addresses,
9595
const TransferRequest* const requestsBegin, const TransferRequest* const requestsEnd,
96-
system::logger_opt_ptr logger, const uint32_t baseDWORD, const uint32_t endDWORD
96+
system::logger_opt_ptr logger, const uint32_t baseOffsetBytes, const uint32_t endOffsetBytes
9797
)
9898
{
9999
if (requestsBegin==requestsEnd)
@@ -158,10 +158,11 @@ bool CPropertyPoolHandler::transferProperties(
158158

159159
cmdbuf->bindComputePipeline(m_pipeline.get());
160160

161-
nbl::hlsl::property_pools::GlobalPushContants pushConstants;
161+
nbl::hlsl::property_pools::TransferDispatchInfo pushConstants;
162162
{
163-
pushConstants.beginOffset = baseDWORD;
164-
pushConstants.endOffset = endDWORD;
163+
// TODO: Should the offset bytes be handled elsewhere?
164+
pushConstants.beginOffset = baseOffsetBytes;
165+
pushConstants.endOffset = endOffsetBytes;
165166
pushConstants.transferCommandsAddress = scratchBufferDeviceAddr;
166167
}
167168
assert(getAlignment(scratchBufferDeviceAddr) == 0);
@@ -172,7 +173,7 @@ bool CPropertyPoolHandler::transferProperties(
172173
{
173174
const auto& limits = m_device->getPhysicalDevice()->getLimits();
174175
const auto invocationCoarseness = limits.maxOptimallyResidentWorkgroupInvocations * requestsThisPass;
175-
cmdbuf->dispatch(limits.computeOptimalPersistentWorkgroupDispatchSize(maxElements,invocationCoarseness), requestsThisPass, 1u);
176+
cmdbuf->dispatch((maxElements - 1) / nbl::hlsl::property_pools::OptimalDispatchSize + 1, requestsThisPass, 1u);
176177
}
177178
// TODO: pipeline barrier
178179
}

0 commit comments

Comments
 (0)