|
| 1 | +#include "nbl/builtin/hlsl/jit/device_capabilities.hlsl" |
| 2 | +#include "nbl/builtin/hlsl/property_pool/transfer.hlsl" |
| 3 | + |
| 4 | +namespace nbl |
| 5 | +{ |
| 6 | +namespace hlsl |
| 7 | +{ |
| 8 | +namespace property_pools |
| 9 | +{ |
// Workaround: DXC cannot drive the workgroup size from a specialization/JIT value
// directly, so the size is supplied by overriding this function.
// https://github.com/microsoft/DirectXShaderCompiler/issues/6144
template<typename device_traits = nbl::hlsl::jit::device_capabilities_traits>
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
{
    // 1D dispatch: X is the optimal resident invocation count, Y and Z are unused.
    const uint32_t workgroupSizeX = device_traits::maxOptimallyResidentWorkgroupInvocations;
    return uint32_t3(workgroupSizeX, 1, 1);
}
| 15 | + |
| 16 | +[[vk::push_constant]] GlobalPushContants globals; |
| 17 | + |
// Innermost transfer loop. All variation that is constant over a request (fill
// mode, iota indexing, index sizes) is baked in as template parameters so the
// loop body itself contains no divergent branching.
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2, uint64_t DstIndexSizeLog2>
struct TransferLoop
{
    // Transfers one element's worth of data for a single invocation.
    void iteration(uint propertyId, uint64_t propertySize, uint64_t srcAddr, uint64_t dstAddr, uint invocationIndex)
    {
        // Per-invocation byte offsets.
        // NOTE(review): these multiply by BOTH the index size (1 << *IndexSizeLog2)
        // and propertySize — confirm the intended addressing scheme (an index-buffer
        // offset would normally scale by the index size alone).
        // bugfix: these are 64-bit device-address offsets; storing them in `uint`
        // truncated the upper 32 bits.
        const uint64_t srcOffset = uint64_t(invocationIndex) * (uint64_t(1) << SrcIndexSizeLog2) * propertySize;
        const uint64_t dstOffset = uint64_t(invocationIndex) * (uint64_t(1) << DstIndexSizeLog2) * propertySize;

        // NOTE(review): for a fill one would usually expect the source address NOT
        // to advance per element — this condition looks inverted; confirm.
        const uint64_t srcIndexAddress = Fill ? srcAddr + srcOffset : srcAddr;
        const uint64_t dstIndexAddress = Fill ? dstAddr + dstOffset : dstAddr;

        // Iota = identity mapping: use the computed address directly; otherwise
        // dereference it to fetch the real address from the index buffer.
        // bugfix: RawBufferLoad<uint64_t> results were being truncated into `uint`.
        const uint64_t srcAddressMapped = SrcIndexIota ? srcIndexAddress : vk::RawBufferLoad<uint64_t>(srcIndexAddress);
        const uint64_t dstAddressMapped = DstIndexIota ? dstIndexAddress : vk::RawBufferLoad<uint64_t>(dstIndexAddress);

        // Compile-time dispatch on transfer granularity; SrcIndexSizeLog2 is a
        // template parameter, so the dead branches fold away.
        // NOTE(review): granularity is keyed on SrcIndexSizeLog2 (the index size),
        // not on propertySize — verify that is intentional.
        if (SrcIndexSizeLog2 == 0) {} // we can't write individual bytes
        else if (SrcIndexSizeLog2 == 1) vk::RawBufferStore<uint16_t>(dstAddressMapped, vk::RawBufferLoad<uint16_t>(srcAddressMapped));
        else if (SrcIndexSizeLog2 == 2) vk::RawBufferStore<uint32_t>(dstAddressMapped, vk::RawBufferLoad<uint32_t>(srcAddressMapped));
        else if (SrcIndexSizeLog2 == 3) vk::RawBufferStore<uint64_t>(dstAddressMapped, vk::RawBufferLoad<uint64_t>(srcAddressMapped));
    }

    // Grid-strided loop over [globals.beginOffset + base, lastInvocation), stepping
    // by the dispatch size so the element count may exceed the invocation count.
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        // bugfix: was "gloabls" — an undeclared identifier.
        uint lastInvocation = min(transferRequest.elementCount, globals.endOffset);
        for (uint invocationIndex = globals.beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
        {
            iteration(propertyId, transferRequest.propertySize, transferRequest.srcAddr, transferRequest.dstAddr, invocationIndex);
        }
    }
}; // bugfix: struct definitions require a trailing semicolon
| 47 | + |
| 48 | +// For creating permutations of the functions based on parameters that are constant over the transfer request |
| 49 | +// These branches should all be scalar, and because of how templates work, the loops shouldn't have any |
| 50 | +// branching within them |
| 51 | + |
// Promotes the runtime dstIndexSizeLog2 value into a compile-time template
// parameter of TransferLoop. These branches are scalar (uniform per request).
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
struct TransferLoopPermutationSrcIndexSizeLog
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        // bugfix: a method cannot be invoked on a *type* (TransferLoop<...>.copyLoop);
        // an instance must be created first.
        if (transferRequest.dstIndexSizeLog2 == 0)
        {
            TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else if (transferRequest.dstIndexSizeLog2 == 1)
        {
            TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else if (transferRequest.dstIndexSizeLog2 == 2)
        {
            TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else // dstIndexSizeLog2 == 3
        {
            TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
    }
}; // bugfix: missing semicolon after struct
| 63 | + |
// Promotes the runtime srcIndexSizeLog2 value into a compile-time template
// parameter. These branches are scalar (uniform per request).
template<bool Fill, bool SrcIndexIota, bool DstIndexIota>
struct TransferLoopPermutationDstIota
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        // bugfix: a method cannot be invoked on a *type*; instantiate the struct first.
        if (transferRequest.srcIndexSizeLog2 == 0)
        {
            TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else if (transferRequest.srcIndexSizeLog2 == 1)
        {
            TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else if (transferRequest.srcIndexSizeLog2 == 2)
        {
            TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else // srcIndexSizeLog2 == 3
        {
            TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
    }
}; // bugfix: missing semicolon after struct
| 75 | + |
// Promotes "destination indices are iota (identity)" into a compile-time bool.
template<bool Fill, bool SrcIndexIota>
struct TransferLoopPermutationSrcIota
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        // NOTE(review): iota-ness is derived from dstAddr == 0, but the struct loaded
        // in main() also carries a dstIndexAddr field — confirm which address is
        // supposed to signal "no index buffer".
        bool dstIota = transferRequest.dstAddr == 0;
        // bugfix: a method cannot be invoked on a *type*; instantiate the struct first.
        if (dstIota)
        {
            TransferLoopPermutationDstIota<Fill, SrcIndexIota, true> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else
        {
            TransferLoopPermutationDstIota<Fill, SrcIndexIota, false> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
    }
}; // bugfix: missing semicolon after struct
| 86 | + |
// Promotes "source indices are iota (identity)" into a compile-time bool.
template<bool Fill>
struct TransferLoopPermutationFill
{
    void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
    {
        // NOTE(review): iota-ness is derived from srcAddr == 0, not srcIndexAddr —
        // confirm which address is supposed to signal "no index buffer".
        bool srcIota = transferRequest.srcAddr == 0;
        // bugfix: a method cannot be invoked on a *type*; instantiate the struct first.
        if (srcIota)
        {
            TransferLoopPermutationSrcIota<Fill, true> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
        else
        {
            TransferLoopPermutationSrcIota<Fill, false> loop;
            loop.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
        }
    }
}; // bugfix: missing semicolon after struct
| 97 | + |
// Compute entry point: dispatchId.x selects the invocation within a property's
// element range, dispatchId.y selects the property.
// NOTE(review): there is no [numthreads(...)] attribute here; the workgroup size
// is meant to come from the gl_WorkGroupSize() workaround above — confirm the
// compiler/pipeline actually honors that.
void main(uint32_t3 dispatchId : SV_DispatchThreadID)
{
    const uint propertyId = dispatchId.y;
    const uint invocationIndex = dispatchId.x;

    // Load the transfer request field-by-field from its device address
    // (can't load a whole struct with BDA on HLSL SPIR-V).
    // bugfix: was declared `const`, which would forbid the member assignments below.
    TransferRequest transferRequest;
    transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress);
    transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t));
    transferRequest.srcIndexAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 2);
    transferRequest.dstIndexAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 3);
    // TODO: These are all part of the same bitfield and should be read with a single RawBufferLoad
    transferRequest.elementCount = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 4);
    transferRequest.propertySize = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 5);
    transferRequest.fill = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 6);
    transferRequest.srcIndexSizeLog2 = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 7);
    transferRequest.dstIndexSizeLog2 = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof(uint64_t) * 8);

    // bugfix: `capability_traits` is only a template parameter of gl_WorkGroupSize()
    // and is not in scope here; use the JIT traits type directly (must match the
    // default used by the gl_WorkGroupSize() workaround above).
    const uint dispatchSize = nbl::hlsl::jit::device_capabilities_traits::maxOptimallyResidentWorkgroupInvocations;
    const bool fill = transferRequest.fill == 1;

    // bugfix: a method cannot be invoked on a *type*; instantiate the struct first.
    if (fill)
    {
        TransferLoopPermutationFill<true> loop;
        loop.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize);
    }
    else
    {
        TransferLoopPermutationFill<false> loop;
        loop.copyLoop(invocationIndex, propertyId, transferRequest, dispatchSize);
    }
}
| 123 | + |
| 124 | +} |
| 125 | +} |
| 126 | +} |
| 127 | + |
0 commit comments