2
2
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
3
3
#include "nbl/builtin/hlsl/property_pool/transfer.hlsl"
4
4
5
- // https://github.com/microsoft/DirectXShaderCompiler/issues/6144
6
- // template<typename capability_traits=nbl::hlsl::jit::device_capabilities_traits>
7
- // uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize() {
8
- // return uint32_t3(capability_traits::maxOptimallyResidentWorkgroupInvocations, 1, 1);
9
- // }
10
-
11
- [[numthreads (1 , 1 , 1 )]
12
- void main (uint32_t3 dispatchId : SV_DispatchThreadID )
13
- {
14
- nbl::hlsl::property_pool::main (dispatchId);
15
- }
16
-
17
5
namespace nbl
18
6
{
19
7
namespace hlsl
@@ -28,14 +16,14 @@ struct TransferLoop
28
16
{
29
17
void iteration (uint propertyId, uint64_t propertySize, uint64_t srcAddr, uint64_t dstAddr, uint invocationIndex)
30
18
{
31
- const uint srcOffset = uint64_t (invocationIndex) * (uint64_t (1 ) << SrcIndexSizeLog2) * propertySize;
32
- const uint dstOffset = uint64_t (invocationIndex) * (uint64_t (1 ) << DstIndexSizeLog2) * propertySize;
19
+ const uint64_t srcOffset = uint64_t (invocationIndex) * (uint64_t (1 ) << SrcIndexSizeLog2) * propertySize;
20
+ const uint64_t dstOffset = uint64_t (invocationIndex) * (uint64_t (1 ) << DstIndexSizeLog2) * propertySize;
33
21
34
- const uint srcIndexAddress = Fill ? srcAddr + srcOffset : srcAddr;
35
- const uint dstIndexAddress = Fill ? dstAddr + dstOffset : dstAddr;
22
+ const uint64_t srcIndexAddress = Fill ? srcAddr + srcOffset : srcAddr;
23
+ const uint64_t dstIndexAddress = Fill ? dstAddr + dstOffset : dstAddr;
36
24
37
- const uint srcAddressMapped = SrcIndexIota ? srcIndexAddress : vk::RawBufferLoad<uint64_t>(srcIndexAddress);
38
- const uint dstAddressMapped = DstIndexIota ? dstIndexAddress : vk::RawBufferLoad<uint64_t>(dstIndexAddress);
25
+ const uint64_t srcAddressMapped = SrcIndexIota ? srcIndexAddress : vk::RawBufferLoad<uint64_t>(srcIndexAddress);
26
+ const uint64_t dstAddressMapped = DstIndexIota ? dstIndexAddress : vk::RawBufferLoad<uint64_t>(dstIndexAddress);
39
27
40
28
if (SrcIndexSizeLog2 == 0 ) {} // we can't write individual bytes
41
29
else if (SrcIndexSizeLog2 == 1 ) vk::RawBufferStore<uint16_t>(dstAddressMapped, vk::RawBufferLoad<uint16_t>(srcAddressMapped));
@@ -51,21 +39,21 @@ struct TransferLoop
51
39
iteration (propertyId, transferRequest.propertySize, transferRequest.srcAddr, transferRequest.dstAddr, invocationIndex);
52
40
}
53
41
}
54
- };
42
+ };
55
43
56
44
// For creating permutations of the functions based on parameters that are constant over the transfer request.
// These branches should all be scalar, and because templates are compiled statically, the loops shouldn't
// have any branching within them.
59
47
60
48
// Converts the runtime value `transferRequest.dstIndexSizeLog2` into the last
// compile-time argument of `TransferLoop`; the other template arguments are
// passed through unchanged.
template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
struct TransferLoopPermutationSrcIndexSizeLog
{
	// Forwards every argument untouched to the `TransferLoop` instantiation
	// picked from `transferRequest.dstIndexSizeLog2`.
	void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
	{
		if (transferRequest.dstIndexSizeLog2 == 0)
		{
			TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0> perm0;
			perm0.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
			return;
		}
		if (transferRequest.dstIndexSizeLog2 == 1)
		{
			TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1> perm1;
			perm1.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
			return;
		}
		if (transferRequest.dstIndexSizeLog2 == 2)
		{
			TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2> perm2;
			perm2.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
			return;
		}

		// Any value other than 0, 1 or 2 lands here and is treated as dstIndexSizeLog2 == 3.
		TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3> perm3;
		perm3.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
	}
};
71
59
@@ -74,10 +62,10 @@ struct TransferLoopPermutationDstIota
74
62
{
75
63
// Converts the runtime value `transferRequest.srcIndexSizeLog2` into the last
// compile-time argument of `TransferLoopPermutationSrcIndexSizeLog` and
// forwards every argument untouched to that instantiation.
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
	if (transferRequest.srcIndexSizeLog2 == 0)
	{
		TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0> perm0;
		perm0.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
		return;
	}
	if (transferRequest.srcIndexSizeLog2 == 1)
	{
		TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1> perm1;
		perm1.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
		return;
	}
	if (transferRequest.srcIndexSizeLog2 == 2)
	{
		TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2> perm2;
		perm2.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
		return;
	}

	// Any value other than 0, 1 or 2 lands here and is treated as srcIndexSizeLog2 == 3.
	TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3> perm3;
	perm3.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
}
82
70
};
83
71
@@ -87,8 +75,8 @@ struct TransferLoopPermutationSrcIota
87
75
// Chooses the last template argument of `TransferLoopPermutationDstIota` from
// the request: `true` when `transferRequest.dstAddr` is zero, `false` otherwise.
// All arguments are forwarded untouched.
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
	if (transferRequest.dstAddr != 0)
	{
		TransferLoopPermutationDstIota<Fill, SrcIndexIota, false> withDstAddr;
		withDstAddr.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
		return;
	}

	// dstAddr == 0
	TransferLoopPermutationDstIota<Fill, SrcIndexIota, true> withoutDstAddr;
	withoutDstAddr.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
}
93
81
};
94
82
@@ -98,19 +86,20 @@ struct TransferLoopPermutationFill
98
86
// Chooses the last template argument of `TransferLoopPermutationSrcIota` from
// the request: `true` when `transferRequest.srcAddr` is zero, `false` otherwise.
// All arguments are forwarded untouched.
void copyLoop(uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
{
	if (transferRequest.srcAddr != 0)
	{
		TransferLoopPermutationSrcIota<Fill, false> withSrcAddr;
		withSrcAddr.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
		return;
	}

	// srcAddr == 0
	TransferLoopPermutationSrcIota<Fill, true> withoutSrcAddr;
	withoutSrcAddr.copyLoop(baseInvocationIndex, propertyId, transferRequest, dispatchSize);
}
104
92
};
105
93
94
+ template<typename device_capabilities>
106
95
void main (uint32_t3 dispatchId)
107
96
{
108
97
const uint propertyId = dispatchId.y;
109
98
const uint invocationIndex = dispatchId.x;
110
99
111
100
// Loading transfer request from the pointer (can't use struct
112
101
// with BDA on HLSL SPIRV)
113
- const TransferRequest transferRequest;
102
+ TransferRequest transferRequest;
114
103
transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress);
115
104
transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof (uint64_t));
116
105
transferRequest.srcIndexAddr = vk::RawBufferLoad<uint64_t>(globals.transferCommandsAddress + sizeof (uint64_t) * 2 );
@@ -124,13 +113,20 @@ void main(uint32_t3 dispatchId)
124
113
transferRequest.srcIndexSizeLog2 = bitfieldType >> (35 + 24 + 1 );
125
114
transferRequest.dstIndexSizeLog2 = bitfieldType >> (35 + 24 + 1 + 2 );
126
115
127
- const uint dispatchSize = capability_traits ::maxOptimallyResidentWorkgroupInvocations;
116
+ const uint dispatchSize = nbl::hlsl::device_capabilities_traits<device_capabilities> ::maxOptimallyResidentWorkgroupInvocations;
128
117
const bool fill = transferRequest.fill == 1 ;
129
118
130
- if (fill) TransferLoopPermutationFill<true >.copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize);
131
- else TransferLoopPermutationFill<false >.copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize);
119
+ if (fill) { TransferLoopPermutationFill<true > loop; loop .copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize); }
120
+ else { TransferLoopPermutationFill<false > loop; loop .copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize); }
132
121
}
133
122
134
123
}
135
124
}
136
125
}
126
+
127
// Shader entry point: hands the dispatch thread ID to the templated
// `nbl::hlsl::property_pools::main`, instantiated with
// `nbl::hlsl::jit::device_capabilities`.
// NOTE(review): the group size is hard-coded to 1x1x1 here, while the callee
// reads `dispatchSize` from `maxOptimallyResidentWorkgroupInvocations`.
// Confirm the two are meant to differ.
[numthreads(1, 1, 1)]
void main(uint32_t3 dispatchId : SV_DispatchThreadID)
{
	nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId);
}
132
+
0 commit comments