@@ -45,143 +45,154 @@ void DxilPIXAddTidToAmplificationShaderPayload::applyOptions(PassOptions O) {
45
45
}
46
46
47
47
void AddValueToExpandedPayload (OP *HlslOP, llvm::IRBuilder<> &B,
48
- ExpandedStruct &expanded,
49
48
AllocaInst *NewStructAlloca,
50
49
unsigned int expandedValueIndex, Value *value) {
51
50
Constant *Zero32Arg = HlslOP->GetU32Const (0 );
52
51
SmallVector<Value *, 2 > IndexToAppendedValue;
53
52
IndexToAppendedValue.push_back (Zero32Arg);
54
53
IndexToAppendedValue.push_back (HlslOP->GetU32Const (expandedValueIndex));
55
54
auto *PointerToEmbeddedNewValue = B.CreateInBoundsGEP (
56
- expanded. ExpandedPayloadStructType , NewStructAlloca, IndexToAppendedValue,
55
+ NewStructAlloca, IndexToAppendedValue,
57
56
" PointerToEmbeddedNewValue" + std::to_string (expandedValueIndex));
58
57
B.CreateStore (value, PointerToEmbeddedNewValue);
59
58
}
60
59
61
- bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule (Module &M) {
60
+ void CopyAggregate (IRBuilder<> &B, Type *Ty, Value *Source, Value *Dest,
61
+ ArrayRef<Value *> GEPIndices) {
62
+ if (StructType *ST = dyn_cast<StructType>(Ty)) {
63
+ SmallVector<Value *, 16 > StructIndices;
64
+ StructIndices.append (GEPIndices.begin (), GEPIndices.end ());
65
+ StructIndices.push_back (nullptr );
66
+ for (unsigned j = 0 ; j < ST->getNumElements (); ++j) {
67
+ StructIndices.back () = B.getInt32 (j);
68
+ CopyAggregate (B, ST->getElementType (j), Source, Dest, StructIndices);
69
+ }
70
+ } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
71
+ SmallVector<Value *, 16 > StructIndices;
72
+ StructIndices.append (GEPIndices.begin (), GEPIndices.end ());
73
+ StructIndices.push_back (nullptr );
74
+ for (unsigned j = 0 ; j < AT->getNumElements (); ++j) {
75
+ StructIndices.back () = B.getInt32 (j);
76
+ CopyAggregate (B, AT->getArrayElementType (), Source, Dest, StructIndices);
77
+ }
78
+ } else {
79
+ auto *SourceGEP = B.CreateGEP (Source, GEPIndices, " CopyStructSourceGEP" );
80
+ Value *Val = B.CreateLoad (SourceGEP, " CopyStructLoad" );
81
+ auto *DestGEP = B.CreateGEP (Dest, GEPIndices, " CopyStructDestGEP" );
82
+ B.CreateStore (Val, DestGEP, " CopyStructStore" );
83
+ }
84
+ }
62
85
86
+ bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule (Module &M) {
63
87
DxilModule &DM = M.GetOrCreateDxilModule ();
64
88
LLVMContext &Ctx = M.getContext ();
65
89
OP *HlslOP = DM.GetOP ();
66
-
67
- Type *OriginalPayloadStructPointerType = nullptr ;
68
- Type *OriginalPayloadStructType = nullptr ;
69
- ExpandedStruct expanded;
70
90
llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction (DM);
71
91
for (inst_iterator I = inst_begin (entryFunction), E = inst_end (entryFunction);
72
92
I != E; ++I) {
73
- if (auto *Instr = llvm::cast<Instruction>(&*I)) {
74
- if (hlsl::OP::IsDxilOpFuncCallInst (Instr,
75
- hlsl::OP::OpCode::DispatchMesh)) {
76
- DxilInst_DispatchMesh DispatchMesh (Instr);
77
- OriginalPayloadStructPointerType =
78
- DispatchMesh.get_payload ()->getType ();
79
- OriginalPayloadStructType =
80
- OriginalPayloadStructPointerType->getPointerElementType ();
81
- expanded = ExpandStructType (Ctx, OriginalPayloadStructType);
82
- }
83
- }
84
- }
85
-
86
- AllocaInst *OldStructAlloca = nullptr ;
87
- AllocaInst *NewStructAlloca = nullptr ;
88
- std::vector<AllocaInst *> allocasOfPayloadType;
89
- for (inst_iterator I = inst_begin (entryFunction), E = inst_end (entryFunction);
90
- I != E; ++I) {
91
- auto *Inst = &*I;
92
- if (llvm::isa<AllocaInst>(Inst)) {
93
- auto *Alloca = llvm::cast<AllocaInst>(Inst);
94
- if (Alloca->getType () == OriginalPayloadStructPointerType) {
95
- allocasOfPayloadType.push_back (Alloca);
96
- }
93
+ if (hlsl::OP::IsDxilOpFuncCallInst (&*I, hlsl::OP::OpCode::DispatchMesh)) {
94
+ DxilInst_DispatchMesh DispatchMesh (&*I);
95
+ Type *OriginalPayloadStructPointerType =
96
+ DispatchMesh.get_payload ()->getType ();
97
+ Type *OriginalPayloadStructType =
98
+ OriginalPayloadStructPointerType->getPointerElementType ();
99
+ ExpandedStruct expanded =
100
+ ExpandStructType (Ctx, OriginalPayloadStructType);
101
+
102
+ llvm::IRBuilder<> B (&*I);
103
+
104
+ auto *NewStructAlloca =
105
+ B.CreateAlloca (expanded.ExpandedPayloadStructType ,
106
+ HlslOP->GetU32Const (1 ), " NewPayload" );
107
+ NewStructAlloca->setAlignment (4 );
108
+ auto PayloadType =
109
+ llvm::dyn_cast<PointerType>(DispatchMesh.get_payload ()->getType ());
110
+ SmallVector<Value *, 16 > GEPIndices;
111
+ GEPIndices.push_back (B.getInt32 (0 ));
112
+ CopyAggregate (B, PayloadType->getPointerElementType (),
113
+ DispatchMesh.get_payload (), NewStructAlloca, GEPIndices);
114
+
115
+ Constant *Zero32Arg = HlslOP->GetU32Const (0 );
116
+ Constant *One32Arg = HlslOP->GetU32Const (1 );
117
+ Constant *Two32Arg = HlslOP->GetU32Const (2 );
118
+
119
+ auto GroupIdFunc =
120
+ HlslOP->GetOpFunc (DXIL::OpCode::GroupId, Type::getInt32Ty (Ctx));
121
+ Constant *GroupIdOpcode =
122
+ HlslOP->GetU32Const ((unsigned )DXIL::OpCode::GroupId);
123
+ auto *GroupIdX =
124
+ B.CreateCall (GroupIdFunc, {GroupIdOpcode, Zero32Arg}, " GroupIdX" );
125
+ auto *GroupIdY =
126
+ B.CreateCall (GroupIdFunc, {GroupIdOpcode, One32Arg}, " GroupIdY" );
127
+ auto *GroupIdZ =
128
+ B.CreateCall (GroupIdFunc, {GroupIdOpcode, Two32Arg}, " GroupIdZ" );
129
+
130
+ // FlatGroupID = z + y*numZ + x*numY*numZ
131
+ // Where x,y,z are the group ID components, and numZ and numY are the
132
+ // corresponding AS group-count arguments to the DispatchMesh Direct3D API
133
+ auto *GroupYxNumZ = B.CreateMul (
134
+ GroupIdY, HlslOP->GetU32Const (m_DispatchArgumentZ), " GroupYxNumZ" );
135
+ auto *FlatGroupNumZY =
136
+ B.CreateAdd (GroupIdZ, GroupYxNumZ, " FlatGroupNumZY" );
137
+ auto *GroupXxNumYZ = B.CreateMul (
138
+ GroupIdX,
139
+ HlslOP->GetU32Const (m_DispatchArgumentY * m_DispatchArgumentZ),
140
+ " GroupXxNumYZ" );
141
+ auto *FlatGroupID =
142
+ B.CreateAdd (GroupXxNumYZ, FlatGroupNumZY, " FlatGroupID" );
143
+
144
+ // The ultimate goal is a single unique thread ID for this AS thread.
145
+ // So take the flat group number, multiply it by the number of
146
+ // threads per group...
147
+ auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul (
148
+ FlatGroupID,
149
+ HlslOP->GetU32Const (DM.GetNumThreads (0 ) * DM.GetNumThreads (1 ) *
150
+ DM.GetNumThreads (2 )),
151
+ " FlatGroupIDWithSpaceForThreadInGroupId" );
152
+
153
+ auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc (
154
+ DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty (Ctx));
155
+ Constant *FlattenedThreadIdInGroupOpcode =
156
+ HlslOP->GetU32Const ((unsigned )DXIL::OpCode::FlattenedThreadIdInGroup);
157
+ auto FlatThreadIdInGroup = B.CreateCall (FlattenedThreadIdInGroupFunc,
158
+ {FlattenedThreadIdInGroupOpcode},
159
+ " FlattenedThreadIdInGroup" );
160
+
161
+ // ...and add the flat thread id:
162
+ auto *FlatId = B.CreateAdd (FlatGroupIDWithSpaceForThreadInGroupId,
163
+ FlatThreadIdInGroup, " FlatId" );
164
+
165
+ AddValueToExpandedPayload (
166
+ HlslOP, B, NewStructAlloca,
167
+ expanded.ExpandedPayloadStructType ->getStructNumElements () - 3 ,
168
+ FlatId);
169
+ AddValueToExpandedPayload (
170
+ HlslOP, B, NewStructAlloca,
171
+ expanded.ExpandedPayloadStructType ->getStructNumElements () - 2 ,
172
+ DispatchMesh.get_threadGroupCountY ());
173
+ AddValueToExpandedPayload (
174
+ HlslOP, B, NewStructAlloca,
175
+ expanded.ExpandedPayloadStructType ->getStructNumElements () - 1 ,
176
+ DispatchMesh.get_threadGroupCountZ ());
177
+
178
+ auto DispatchMeshFn = HlslOP->GetOpFunc (
179
+ DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType );
180
+ Constant *DispatchMeshOpcode =
181
+ HlslOP->GetU32Const ((unsigned )DXIL::OpCode::DispatchMesh);
182
+ B.CreateCall (DispatchMeshFn,
183
+ {DispatchMeshOpcode, DispatchMesh.get_threadGroupCountX (),
184
+ DispatchMesh.get_threadGroupCountY (),
185
+ DispatchMesh.get_threadGroupCountZ (), NewStructAlloca});
186
+ I->removeFromParent ();
187
+ delete &*I;
188
+ // Validation requires exactly one DispatchMesh in an AS, so we can exit
189
+ // after the first one:
190
+ DM.ReEmitDxilResources ();
191
+ return true ;
97
192
}
98
193
}
99
- for (auto &Alloca : allocasOfPayloadType) {
100
- OldStructAlloca = Alloca;
101
- llvm::IRBuilder<> B (Alloca->getContext ());
102
- NewStructAlloca = B.CreateAlloca (expanded.ExpandedPayloadStructType ,
103
- HlslOP->GetU32Const (1 ), " NewPayload" );
104
- NewStructAlloca->setAlignment (Alloca->getAlignment ());
105
- NewStructAlloca->insertAfter (Alloca);
106
-
107
- ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction (
108
- Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType );
109
- }
110
-
111
- auto F = HlslOP->GetOpFunc (DXIL::OpCode::DispatchMesh,
112
- expanded.ExpandedPayloadStructPtrType );
113
- for (auto FI = F->user_begin (); FI != F->user_end ();) {
114
- auto *FunctionUser = *FI++;
115
- auto *UserInstruction = llvm::cast<Instruction>(FunctionUser);
116
- DxilInst_DispatchMesh DispatchMesh (UserInstruction);
117
-
118
- llvm::IRBuilder<> B (UserInstruction);
119
-
120
- Constant *Zero32Arg = HlslOP->GetU32Const (0 );
121
- Constant *One32Arg = HlslOP->GetU32Const (1 );
122
- Constant *Two32Arg = HlslOP->GetU32Const (2 );
123
-
124
- auto GroupIdFunc =
125
- HlslOP->GetOpFunc (DXIL::OpCode::GroupId, Type::getInt32Ty (Ctx));
126
- Constant *GroupIdOpcode =
127
- HlslOP->GetU32Const ((unsigned )DXIL::OpCode::GroupId);
128
- auto *GroupIdX =
129
- B.CreateCall (GroupIdFunc, {GroupIdOpcode, Zero32Arg}, " GroupIdX" );
130
- auto *GroupIdY =
131
- B.CreateCall (GroupIdFunc, {GroupIdOpcode, One32Arg}, " GroupIdY" );
132
- auto *GroupIdZ =
133
- B.CreateCall (GroupIdFunc, {GroupIdOpcode, Two32Arg}, " GroupIdZ" );
134
-
135
- // FlatGroupID = z + y*numZ + x*numY*numZ
136
- // Where x,y,z are the group ID components, and numZ and numY are the
137
- // corresponding AS group-count arguments to the DispatchMesh Direct3D API
138
- auto *GroupYxNumZ = B.CreateMul (
139
- GroupIdY, HlslOP->GetU32Const (m_DispatchArgumentZ), " GroupYxNumZ" );
140
- auto *FlatGroupNumZY = B.CreateAdd (GroupIdZ, GroupYxNumZ, " FlatGroupNumZY" );
141
- auto *GroupXxNumYZ = B.CreateMul (
142
- GroupIdX,
143
- HlslOP->GetU32Const (m_DispatchArgumentY * m_DispatchArgumentZ),
144
- " GroupXxNumYZ" );
145
- auto *FlatGroupID =
146
- B.CreateAdd (GroupXxNumYZ, FlatGroupNumZY, " FlatGroFlatGroupIDupNum" );
147
-
148
- // The ultimate goal is a single unique thread ID for this AS thread.
149
- // So take the flat group number, multiply it by the number of
150
- // threads per group...
151
- auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul (
152
- FlatGroupID,
153
- HlslOP->GetU32Const (DM.GetNumThreads (0 ) * DM.GetNumThreads (1 ) *
154
- DM.GetNumThreads (2 )),
155
- " FlatGroupIDWithSpaceForThreadInGroupId" );
156
-
157
- auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc (
158
- DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty (Ctx));
159
- Constant *FlattenedThreadIdInGroupOpcode =
160
- HlslOP->GetU32Const ((unsigned )DXIL::OpCode::FlattenedThreadIdInGroup);
161
- auto FlatThreadIdInGroup = B.CreateCall (FlattenedThreadIdInGroupFunc,
162
- {FlattenedThreadIdInGroupOpcode},
163
- " FlattenedThreadIdInGroup" );
164
-
165
- // ...and add the flat thread id:
166
- auto *FlatId = B.CreateAdd (FlatGroupIDWithSpaceForThreadInGroupId,
167
- FlatThreadIdInGroup, " FlatId" );
168
-
169
- AddValueToExpandedPayload (HlslOP, B, expanded, NewStructAlloca,
170
- OriginalPayloadStructType->getStructNumElements (),
171
- FlatId);
172
- AddValueToExpandedPayload (
173
- HlslOP, B, expanded, NewStructAlloca,
174
- OriginalPayloadStructType->getStructNumElements () + 1 ,
175
- DispatchMesh.get_threadGroupCountY ());
176
- AddValueToExpandedPayload (
177
- HlslOP, B, expanded, NewStructAlloca,
178
- OriginalPayloadStructType->getStructNumElements () + 2 ,
179
- DispatchMesh.get_threadGroupCountZ ());
180
- }
181
-
182
- DM.ReEmitDxilResources ();
183
194
184
- return true ;
195
+ return false ;
185
196
}
186
197
187
198
char DxilPIXAddTidToAmplificationShaderPayload::ID = 0 ;
0 commit comments