Skip to content

Commit 5cd56c9

Browse files
[MLIR][NVVM] Remove Pure trait from clock, clock64, globaltimer Ops (#147608)
This commit removes Pure trait from clock, clock64 and globaltimer Ops by creating NVVM_NCSpecialRegisterOp class to represent Ops which return non-constant values. This prevents CSE pass from optimizing awayredundant uses of them
1 parent aee21c3 commit 5cd56c9

File tree

2 files changed

+85
-42
lines changed

2 files changed

+85
-42
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,20 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
153153
// NVVM special register op definitions
154154
//===----------------------------------------------------------------------===//
155155

156-
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
156+
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
157157
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
158158
let arguments = (ins);
159159
let assemblyFormat = "attr-dict `:` type($res)";
160160
}
161161

162-
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
163-
NVVM_SpecialRegisterOp<mnemonic,
162+
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
163+
NVVM_IntrOp<mnemonic, traits, 1> {
164+
let arguments = (ins);
165+
let assemblyFormat = "attr-dict `:` type($res)";
166+
}
167+
168+
class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
169+
NVVM_PureSpecialRegisterOp<mnemonic,
164170
!listconcat(traits,
165171
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
166172
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
@@ -189,63 +195,63 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
189195

190196
//===----------------------------------------------------------------------===//
191197
// Lane, Warp, SM, Grid index and range
192-
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
193-
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
194-
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
195-
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
196-
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
197-
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
198-
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
198+
def NVVM_LaneIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
199+
def NVVM_WarpSizeOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
200+
def NVVM_WarpIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
201+
def NVVM_WarpDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
202+
def NVVM_SmIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
203+
def NVVM_SmDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
204+
def NVVM_GridIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
199205

200206
//===----------------------------------------------------------------------===//
201207
// Lane Mask Comparison Ops
202-
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
203-
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
204-
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
205-
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
206-
def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
208+
def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
209+
def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
210+
def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
211+
def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
212+
def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
207213

208214
//===----------------------------------------------------------------------===//
209215
// Thread index and range
210-
def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
211-
def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
212-
def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
213-
def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
214-
def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
215-
def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
216+
def NVVM_ThreadIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
217+
def NVVM_ThreadIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
218+
def NVVM_ThreadIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
219+
def NVVM_BlockDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
220+
def NVVM_BlockDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
221+
def NVVM_BlockDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
216222

217223
//===----------------------------------------------------------------------===//
218224
// Block index and range
219-
def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
220-
def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
221-
def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
222-
def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
223-
def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
224-
def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
225+
def NVVM_BlockIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
226+
def NVVM_BlockIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
227+
def NVVM_BlockIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
228+
def NVVM_GridDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
229+
def NVVM_GridDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
230+
def NVVM_GridDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
225231

226232
//===----------------------------------------------------------------------===//
227233
// CTA Cluster index and range
228-
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
229-
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
230-
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
231-
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
232-
def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
233-
def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
234+
def NVVM_ClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
235+
def NVVM_ClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
236+
def NVVM_ClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
237+
def NVVM_ClusterDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
238+
def NVVM_ClusterDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
239+
def NVVM_ClusterDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
234240

235241

236242
//===----------------------------------------------------------------------===//
237243
// CTA index and range within Cluster
238-
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
239-
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
240-
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
241-
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
242-
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
243-
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
244+
def NVVM_BlockInClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
245+
def NVVM_BlockInClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
246+
def NVVM_BlockInClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
247+
def NVVM_ClusterDimBlocksXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
248+
def NVVM_ClusterDimBlocksYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
249+
def NVVM_ClusterDimBlocksZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
244250

245251
//===----------------------------------------------------------------------===//
246252
// CTA index and across Cluster dimensions
247-
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
248-
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
253+
def NVVM_ClusterId : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
254+
def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
249255

250256
//===----------------------------------------------------------------------===//
251257
// Clock registers
@@ -256,7 +262,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
256262
//===----------------------------------------------------------------------===//
257263
// envreg registers
258264
foreach index = !range(0, 32) in {
259-
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
265+
def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
260266
}
261267

262268
//===----------------------------------------------------------------------===//
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// CHECK-LABEL: @nvvm_special_regs_clock
4+
llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> {
5+
%0 = llvm.mlir.zero: !llvm.struct<(i32, i32)>
6+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
7+
%1 = nvvm.read.ptx.sreg.clock : i32
8+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
9+
%2 = nvvm.read.ptx.sreg.clock : i32
10+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)>
11+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)>
12+
llvm.return %5: !llvm.struct<(i32, i32)>
13+
}
14+
15+
// CHECK-LABEL: @nvvm_special_regs_clock64
16+
llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> {
17+
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
18+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
19+
%1 = nvvm.read.ptx.sreg.clock64 : i64
20+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
21+
%2 = nvvm.read.ptx.sreg.clock64 : i64
22+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
23+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
24+
llvm.return %5: !llvm.struct<(i64, i64)>
25+
}
26+
27+
// CHECK-LABEL: @nvvm_special_regs_globaltimer
28+
llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> {
29+
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
30+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
31+
%1 = nvvm.read.ptx.sreg.globaltimer : i64
32+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
33+
%2 = nvvm.read.ptx.sreg.globaltimer : i64
34+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
35+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
36+
llvm.return %5: !llvm.struct<(i64, i64)>
37+
}

0 commit comments

Comments
 (0)