Work on property pool HLSL impl #649

Draft: wants to merge 19 commits into base: master
Changes from 2 commits
89 changes: 40 additions & 49 deletions include/nbl/builtin/hlsl/property_pool/copy.comp.hlsl
@@ -9,7 +9,7 @@ namespace hlsl
namespace property_pools
{

[[vk::push_constant]] GlobalPushContants globals;
[[vk::push_constant]] TransferDispatchInfo globals;

template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2, uint64_t DstIndexSizeLog2>
struct TransferLoop
@@ -39,12 +39,12 @@ struct TransferLoop
else if (SrcIndexSizeLog2 == 3) vk::RawBufferStore<uint64_t>(dstAddressMapped, vk::RawBufferLoad<uint64_t>(srcAddressMapped));
}

void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
uint64_t elementCount = uint64_t(transferRequest.elementCount32)
| uint64_t(transferRequest.elementCountExtra) << 32;
uint64_t lastInvocation = min(elementCount, globals.endOffset);
for (uint64_t invocationIndex = globals.beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
uint64_t lastInvocation = min(elementCount, dispatchInfo.endOffset);
for (uint64_t invocationIndex = dispatchInfo.beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
{
iteration(propertyId, transferRequest, invocationIndex);
}
@@ -62,58 +62,53 @@ struct TransferLoop
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
struct TransferLoopPermutationSrcIndexSizeLog
{
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
if (transferRequest.dstIndexSizeLog2 == 0) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.dstIndexSizeLog2 == 1) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.dstIndexSizeLog2 == 2) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
if (transferRequest.dstIndexSizeLog2 == 0) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.dstIndexSizeLog2 == 1) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.dstIndexSizeLog2 == 2) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
}
};

template<bool Fill, bool SrcIndexIota, bool DstIndexIota>
struct TransferLoopPermutationDstIota
{
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
if (transferRequest.srcIndexSizeLog2 == 0) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.srcIndexSizeLog2 == 1) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.srcIndexSizeLog2 == 2) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
if (transferRequest.srcIndexSizeLog2 == 0) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.srcIndexSizeLog2 == 1) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else if (transferRequest.srcIndexSizeLog2 == 2) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
}
};

template<bool Fill, bool SrcIndexIota>
struct TransferLoopPermutationSrcIota
{
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
bool dstIota = transferRequest.dstIndexAddr == 0;
if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
}
};

template<bool Fill>
struct TransferLoopPermutationFill
{
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
void copyLoop(NBL_CONST_REF_ARG(TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
bool srcIota = transferRequest.srcIndexAddr == 0;
if (srcIota) { TransferLoopPermutationSrcIota<Fill, true> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else { TransferLoopPermutationSrcIota<Fill, false> loop; loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
if (srcIota) { TransferLoopPermutationSrcIota<Fill, true> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
else { TransferLoopPermutationSrcIota<Fill, false> loop; loop.copyLoop(dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
}
};

template<typename device_capabilities>
void main(uint32_t3 dispatchId)
{
const uint propertyId = dispatchId.y;
const uint invocationIndex = dispatchId.x;

// Loading transfer request from the pointer (can't use struct
// with BDA on HLSL SPIRV)
uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof(TransferRequest) * propertyId;
// Loading transfer request from the pointer (can't use struct
// with BDA on HLSL SPIRV)
static TransferRequest TransferRequest::newFromAddress(const uint64_t transferCmdAddr)
{
Comment on lines +121 to +124

Reviewer: keep it with the struct

Contributor Author: The struct is shared with C++ code, so I wouldn't be able to use vk::RawBufferLoad; I could take the 64-bit value though.

Reviewer: you can use #ifndef __HLSL_VERSION in the impl of the method.
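A minimal sketch of what that #ifndef __HLSL_VERSION suggestion could look like if newFromAddress were kept inside the struct in the shared transfer.hlsl header. Only the vk::RawBufferLoad calls mirror code already in this diff; guarding the body so the header still compiles as C++ (where the function would simply never be called) is an assumption, not part of the PR:

// Sketch only, inside struct TransferRequest in transfer.hlsl
static TransferRequest newFromAddress(const uint64_t address)
{
    TransferRequest transferRequest;
#ifdef __HLSL_VERSION
    // device path: raw BDA loads, same pattern as in copy.comp.hlsl below
    transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(address, 8);
    transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(address + sizeof(uint64_t), 8);
    // ... remaining fields unpacked exactly as in copy.comp.hlsl
#endif
    // host path intentionally left empty (assumed never called from C++)
    return transferRequest;
}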

TransferRequest transferRequest;
transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr,8);
transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr + sizeof(uint64_t),8);
@@ -129,35 +124,31 @@ void main(uint32_t3 dispatchId)
transferRequest.srcIndexSizeLog2 = uint32_t(bitfieldType >> (32 + 3 + 24 + 1));
transferRequest.dstIndexSizeLog2 = uint32_t(bitfieldType >> (32 + 3 + 24 + 1 + 2));

const uint dispatchSize = nbl::hlsl::device_capabilities_traits<device_capabilities>::maxOptimallyResidentWorkgroupInvocations;
return transferRequest;
}

template<typename device_capabilities>
void main(uint32_t3 dispatchId, const uint dispatchSize)
{
const uint propertyId = dispatchId.y;
const uint invocationIndex = dispatchId.x;

uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof(TransferRequest) * propertyId;
TransferRequest transferRequest = TransferRequest::newFromAddress(transferCmdAddr);

const bool fill = transferRequest.fill == 1;

//uint64_t debugWriteAddr = transferRequest.dstAddr + sizeof(uint64_t) * 9 * propertyId;
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 0, transferRequest.srcAddr,8);
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 1, transferRequest.dstAddr,8);
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 2, transferRequest.srcIndexAddr,8);
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 3, transferRequest.dstIndexAddr,8);
//uint64_t elementCount = uint64_t(transferRequest.elementCount32)
// | uint64_t(transferRequest.elementCountExtra) << 32;
//vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 4, elementCount,8);
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 5, transferRequest.propertySize,4);
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 6, transferRequest.fill,4);
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 7, transferRequest.srcIndexSizeLog2,4);
//vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 8, transferRequest.dstIndexSizeLog2,4);
//vk::RawBufferStore<uint64_t>(transferRequest.dstAddr + sizeof(uint64_t) * invocationIndex, invocationIndex,8);

if (fill) { TransferLoopPermutationFill<true> loop; loop.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize); }
else { TransferLoopPermutationFill<false> loop; loop.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize); }
if (fill) { TransferLoopPermutationFill<true> loop; loop.copyLoop(globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
else { TransferLoopPermutationFill<false> loop; loop.copyLoop(globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
}

}
}
}

// TODO: instead use some sort of replace function for getting optimal size?
[numthreads(512,1,1)]
[numthreads(nbl::hlsl::property_pools::OptimalDispatchSize,1,1)]
void main(uint32_t3 dispatchId : SV_DispatchThreadID)
{
nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId);
nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId, nbl::hlsl::property_pools::OptimalDispatchSize);
}
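For reference, the shift arithmetic in newFromAddress above (32 + 3 + 24 + 1 ...) implies a layout for the 64-bit word loaded as bitfieldType: elementCount32 in the low 32 bits, then elementCountExtra (3 bits), propertySize (24 bits), fill (1 bit), srcIndexSizeLog2 (2 bits) and dstIndexSizeLog2 (2 bits). A host-side C++ sketch of packing that word; the field widths are inferred from this diff and the helper name packCountAndFlags is hypothetical:

#include <cstdint>

// Hypothetical helper: packs the 64-bit word that copy.comp.hlsl reads as
// `bitfieldType`, using the same shift amounts as the shader-side unpack.
inline uint64_t packCountAndFlags(
    const uint64_t elementCount,     // low 32 bits -> elementCount32, next 3 bits -> elementCountExtra
    const uint32_t propertySize,     // 24 bits
    const bool fill,                 // 1 bit
    const uint32_t srcIndexSizeLog2, // 2 bits
    const uint32_t dstIndexSizeLog2) // 2 bits
{
    uint64_t word = elementCount & 0xFFFFFFFFull;                            // elementCount32
    word |= ((elementCount >> 32) & 0x7ull) << 32;                           // elementCountExtra
    word |= (uint64_t(propertySize) & 0xFFFFFFull) << (32 + 3);              // propertySize
    word |= uint64_t(fill ? 1u : 0u) << (32 + 3 + 24);                       // fill
    word |= (uint64_t(srcIndexSizeLog2) & 0x3ull) << (32 + 3 + 24 + 1);      // srcIndexSizeLog2
    word |= (uint64_t(dstIndexSizeLog2) & 0x3ull) << (32 + 3 + 24 + 1 + 2);  // dstIndexSizeLog2
    return word;
}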

8 changes: 7 additions & 1 deletion include/nbl/builtin/hlsl/property_pool/transfer.hlsl
@@ -34,9 +34,12 @@ struct TransferRequest
uint32_t fill: 1;
uint32_t srcIndexSizeLog2 : 2;
uint32_t dstIndexSizeLog2 : 2;

// Reads a TransferRequest from a BDA
static TransferRequest newFromAddress(const uint64_t address);
};

struct GlobalPushContants
struct TransferDispatchInfo
{
// BDA address (GPU pointer) into the transfer commands buffer
uint64_t transferCommandsAddress;
@@ -49,6 +52,9 @@

NBL_CONSTEXPR uint32_t MaxPropertiesPerDispatch = 128;

Reviewer: is there any reason to keep this around anymore?


// TODO: instead use some sort of replace function for getting optimal size?
NBL_CONSTEXPR uint32_t OptimalDispatchSize = 256;
Comment on lines +55 to +56

Reviewer: you can use the device JIT to query the max compute dispatch size; I'd round it down to the nearest PoT though, so the divisions aren't expensive.
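A sketch of the rounding that comment suggests, assuming the queried limit (e.g. maxOptimallyResidentWorkgroupInvocations, which this PR already uses on the C++ side) is fed through std::bit_floor; how the JIT device_capabilities would actually expose the limit is not shown, and Nabla may already ship an equivalent utility:

#include <bit>
#include <cstdint>

// Round a queried invocation limit down to the nearest power of two so the
// shader's divisions/modulo by the dispatch size stay cheap.
inline uint32_t roundDownToPoT(const uint32_t workgroupInvocationLimit)
{
    return std::bit_floor(workgroupInvocationLimit); // e.g. 384 -> 256, 1024 -> 1024
}

OptimalDispatchSize could then be derived from this value at pipeline-creation time instead of the hard-coded 256 below.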


}
}
}
3 changes: 1 addition & 2 deletions include/nbl/video/utilities/CPropertyPoolHandler.h
@@ -13,8 +13,7 @@
#include "nbl/video/utilities/IPropertyPool.h"

#include "glm/glm/glm.hpp"
#include <nbl/builtin/hlsl/cpp_compat/matrix.hlsl>
#include <nbl/builtin/hlsl/cpp_compat/vector.hlsl>
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
#include "nbl/builtin/hlsl/property_pool/transfer.hlsl"

namespace nbl::video
10 changes: 5 additions & 5 deletions include/nbl/video/utilities/IPropertyPool.h
@@ -12,8 +12,7 @@
#include "nbl/video/IGPUDescriptorSetLayout.h"

#include "glm/glm/glm.hpp"
#include <nbl/builtin/hlsl/cpp_compat/matrix.hlsl>
#include <nbl/builtin/hlsl/cpp_compat/vector.hlsl>
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
#include "nbl/builtin/hlsl/property_pool/transfer.hlsl"

namespace nbl::video
@@ -27,6 +26,7 @@ class NBL_API2 IPropertyPool : public core::IReferenceCounted
using PropertyAddressAllocator = core::PoolAddressAllocatorST<uint32_t>;

static inline constexpr uint64_t invalid = 0;
using value_type = PropertyAddressAllocator::size_type;
//
virtual const asset::SBufferRange<IGPUBuffer>& getPropertyMemoryBlock(uint32_t ix) const =0;

@@ -38,19 +38,19 @@ class NBL_API2 IPropertyPool : public core::IReferenceCounted
inline bool isContiguous() const {return m_indexToAddr;}

//
inline uint64_t getAllocated() const
inline value_type getAllocated() const
{
return indexAllocator.get_allocated_size();
}

//
inline uint64_t getFree() const
inline value_type getFree() const
{
return indexAllocator.get_free_size();
}

//
inline uint64_t getCapacity() const
inline value_type getCapacity() const
{
// special case allows us to use `get_total_size`, because the pool allocator has no added offsets
return indexAllocator.get_total_size();
41 changes: 31 additions & 10 deletions src/nbl/video/utilities/CPropertyPoolHandler.cpp
@@ -24,8 +24,8 @@ CPropertyPoolHandler::CPropertyPoolHandler(core::smart_refctd_ptr<ILogicalDevice
return shader;
};
auto shader = loadShader("../../../include/nbl/builtin/hlsl/property_pool/copy.comp.hlsl");
const asset::SPushConstantRange baseDWORD = { asset::IShader::ESS_COMPUTE,0u,sizeof(nbl::hlsl::property_pools::GlobalPushContants) };
auto layout = m_device->createPipelineLayout({ &baseDWORD,1u });
const asset::SPushConstantRange transferInfoPushConstants = { asset::IShader::ESS_COMPUTE,0u,sizeof(nbl::hlsl::property_pools::TransferDispatchInfo) };
auto layout = m_device->createPipelineLayout({ &transferInfoPushConstants,1u });

{
video::IGPUComputePipeline::SCreationParams params = {};
@@ -93,7 +93,7 @@ bool CPropertyPoolHandler::transferProperties(
IGPUCommandBuffer* const cmdbuf, //IGPUFence* const fence,
const asset::SBufferBinding<video::IGPUBuffer>& scratch, const asset::SBufferBinding<video::IGPUBuffer>& addresses,
const TransferRequest* const requestsBegin, const TransferRequest* const requestsEnd,
system::logger_opt_ptr logger, const uint32_t baseDWORD, const uint32_t endDWORD
system::logger_opt_ptr logger, const uint32_t baseOffsetBytes, const uint32_t endOffsetBytes
)
{
if (requestsBegin==requestsEnd)
@@ -138,8 +138,28 @@ bool CPropertyPoolHandler::transferProperties(
transferRequest.fill = 0; // TODO
transferRequest.srcIndexSizeLog2 = 1u; // TODO
transferRequest.dstIndexSizeLog2 = 1u; // TODO
assert(getAlignment(transferRequest.srcAddr) == 0);
assert(getAlignment(transferRequest.dstAddr) == 0);
if (getAlignment(transferRequest.srcAddr) != 0)
{
logger.log("CPropertyPoolHandler: memblock.buffer BDA address %I64i is not aligned to 8 byte (64 bit)",system::ILogger::ELL_ERROR, transferRequest.srcAddr);
}
if (getAlignment(transferRequest.dstAddr) != 0)
{
logger.log("CPropertyPoolHandler: buffer.buffer BDA address %I64i is not aligned to 8 byte (64 bit)",system::ILogger::ELL_ERROR, transferRequest.dstAddr);
}
if (getAlignment(transferRequest.propertySize) != 0)
{
logger.log("CPropertyPoolHandler: propertySize %i is not aligned to 8 byte (64 bit)",system::ILogger::ELL_ERROR, srcRequest->elementSize);
}
if (transferRequest.srcIndexSizeLog2 < 1 || transferRequest.srcIndexSizeLog2 > 3)
{
auto srcIndexSizeLog2 = transferRequest.srcIndexSizeLog2;
logger.log("CPropertyPoolHandler: srcIndexSizeLog2 %i (%i bit values) are unsupported",system::ILogger::ELL_ERROR, srcIndexSizeLog2, (1 << transferRequest.srcIndexSizeLog2) * sizeof(uint8_t));
}
if (transferRequest.dstIndexSizeLog2 < 1 || transferRequest.dstIndexSizeLog2 > 3)
{
auto dstIndexSizeLog2 = transferRequest.dstIndexSizeLog2;
logger.log("CPropertyPoolHandler: dstIndexSizeLog2 %i (%i bit values) are unsupported",system::ILogger::ELL_ERROR, dstIndexSizeLog2, (1 << transferRequest.srcIndexSizeLog2) * sizeof(uint8_t));
}

maxElements = core::max<uint64_t>(maxElements, srcRequest->elementCount);
}
@@ -158,11 +178,12 @@ bool CPropertyPoolHandler::transferProperties(

cmdbuf->bindComputePipeline(m_pipeline.get());

nbl::hlsl::property_pools::GlobalPushContants pushConstants;
nbl::hlsl::property_pools::TransferDispatchInfo pushConstants;
{
pushConstants.beginOffset = baseDWORD;
pushConstants.endOffset = endDWORD;
pushConstants.transferCommandsAddress = scratchBufferDeviceAddr;
// TODO: Should the offset bytes be handled elsewhere?
pushConstants.beginOffset = baseOffsetBytes;
pushConstants.endOffset = endOffsetBytes;
pushConstants.transferCommandsAddress = scratchBufferDeviceAddr + transferPassRequestsIndex * sizeof(TransferRequest);
}
assert(getAlignment(scratchBufferDeviceAddr) == 0);
assert(getAlignment(sizeof(TransferRequest)) == 0);
@@ -172,7 +193,7 @@ bool CPropertyPoolHandler::transferProperties(
{
const auto& limits = m_device->getPhysicalDevice()->getLimits();
const auto invocationCoarseness = limits.maxOptimallyResidentWorkgroupInvocations * requestsThisPass;
cmdbuf->dispatch(limits.computeOptimalPersistentWorkgroupDispatchSize(maxElements,invocationCoarseness), requestsThisPass, 1u);
cmdbuf->dispatch((maxElements - 1) / nbl::hlsl::property_pools::OptimalDispatchSize + 1, requestsThisPass, 1u);
}
// TODO: pipeline barrier
}