@@ -153,14 +153,23 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
153
153
// NVVM special register op definitions
154
154
//===----------------------------------------------------------------------===//
155
155
156
- class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
156
+ // NVVM_PureSpecialRegisterOp represents special register ops that can
157
+ // speculated and does not touch memory. These operations are always
158
+ // legal to hoist or sink.
159
+ class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
157
160
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
158
161
let arguments = (ins);
159
162
let assemblyFormat = "attr-dict `:` type($res)";
160
163
}
161
164
162
- class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
163
- NVVM_SpecialRegisterOp<mnemonic,
165
+ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
166
+ NVVM_IntrOp<mnemonic, traits, 1> {
167
+ let arguments = (ins);
168
+ let assemblyFormat = "attr-dict `:` type($res)";
169
+ }
170
+
171
+ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
172
+ NVVM_PureSpecialRegisterOp<mnemonic,
164
173
!listconcat(traits,
165
174
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
166
175
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
@@ -189,63 +198,63 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
189
198
190
199
//===----------------------------------------------------------------------===//
191
200
// Lane, Warp, SM, Grid index and range
192
- def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.laneid">;
193
- def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.warpsize">;
194
- def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.warpid">;
195
- def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nwarpid">;
196
- def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.smid">;
197
- def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nsmid">;
198
- def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.gridid">;
201
+ def NVVM_LaneIdOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.laneid">;
202
+ def NVVM_WarpSizeOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.warpsize">;
203
+ def NVVM_WarpIdOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.warpid">;
204
+ def NVVM_WarpDimOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nwarpid">;
205
+ def NVVM_SmIdOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.smid">;
206
+ def NVVM_SmDimOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nsmid">;
207
+ def NVVM_GridIdOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.gridid">;
199
208
200
209
//===----------------------------------------------------------------------===//
201
210
// Lane Mask Comparison Ops
202
- def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp <"read.ptx.sreg.lanemask.eq">;
203
- def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp <"read.ptx.sreg.lanemask.le">;
204
- def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp <"read.ptx.sreg.lanemask.lt">;
205
- def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp <"read.ptx.sreg.lanemask.ge">;
206
- def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp <"read.ptx.sreg.lanemask.gt">;
211
+ def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp <"read.ptx.sreg.lanemask.eq">;
212
+ def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp <"read.ptx.sreg.lanemask.le">;
213
+ def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp <"read.ptx.sreg.lanemask.lt">;
214
+ def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp <"read.ptx.sreg.lanemask.ge">;
215
+ def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp <"read.ptx.sreg.lanemask.gt">;
207
216
208
217
//===----------------------------------------------------------------------===//
209
218
// Thread index and range
210
- def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.tid.x">;
211
- def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.tid.y">;
212
- def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.tid.z">;
213
- def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.ntid.x">;
214
- def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.ntid.y">;
215
- def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.ntid.z">;
219
+ def NVVM_ThreadIdXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.tid.x">;
220
+ def NVVM_ThreadIdYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.tid.y">;
221
+ def NVVM_ThreadIdZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.tid.z">;
222
+ def NVVM_BlockDimXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.ntid.x">;
223
+ def NVVM_BlockDimYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.ntid.y">;
224
+ def NVVM_BlockDimZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.ntid.z">;
216
225
217
226
//===----------------------------------------------------------------------===//
218
227
// Block index and range
219
- def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.ctaid.x">;
220
- def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.ctaid.y">;
221
- def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.ctaid.z">;
222
- def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nctaid.x">;
223
- def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nctaid.y">;
224
- def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nctaid.z">;
228
+ def NVVM_BlockIdXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.ctaid.x">;
229
+ def NVVM_BlockIdYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.ctaid.y">;
230
+ def NVVM_BlockIdZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.ctaid.z">;
231
+ def NVVM_GridDimXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nctaid.x">;
232
+ def NVVM_GridDimYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nctaid.y">;
233
+ def NVVM_GridDimZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nctaid.z">;
225
234
226
235
//===----------------------------------------------------------------------===//
227
236
// CTA Cluster index and range
228
- def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
229
- def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.clusterid.y">;
230
- def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.clusterid.z">;
231
- def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nclusterid.x">;
232
- def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nclusterid.y">;
233
- def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.nclusterid.z">;
237
+ def NVVM_ClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
238
+ def NVVM_ClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.clusterid.y">;
239
+ def NVVM_ClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.clusterid.z">;
240
+ def NVVM_ClusterDimXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nclusterid.x">;
241
+ def NVVM_ClusterDimYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nclusterid.y">;
242
+ def NVVM_ClusterDimZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.nclusterid.z">;
234
243
235
244
236
245
//===----------------------------------------------------------------------===//
237
246
// CTA index and range within Cluster
238
- def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
239
- def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
240
- def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
241
- def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
242
- def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
243
- def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctaid.z">;
247
+ def NVVM_BlockInClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
248
+ def NVVM_BlockInClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
249
+ def NVVM_BlockInClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
250
+ def NVVM_ClusterDimBlocksXOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
251
+ def NVVM_ClusterDimBlocksYOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
252
+ def NVVM_ClusterDimBlocksZOp : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctaid.z">;
244
253
245
254
//===----------------------------------------------------------------------===//
246
255
// CTA index and across Cluster dimensions
247
- def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
248
- def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctarank">;
256
+ def NVVM_ClusterId : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
257
+ def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp <"read.ptx.sreg.cluster.nctarank">;
249
258
250
259
//===----------------------------------------------------------------------===//
251
260
// Clock registers
@@ -256,7 +265,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
256
265
//===----------------------------------------------------------------------===//
257
266
// envreg registers
258
267
foreach index = !range(0, 32) in {
259
- def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp <"read.ptx.sreg.envreg" # index>;
268
+ def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp <"read.ptx.sreg.envreg" # index>;
260
269
}
261
270
262
271
//===----------------------------------------------------------------------===//
0 commit comments