Skip to content

Commit ddab115

Browse files
[NVVM][MLIR] Remove Pure trait from clock, clock64, globaltimer Ops
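This commit removes the Pure trait from the clock, clock64 and globaltimer ops, which return non-constant values. The previous pure base class is renamed to NVVM_PureSpecialRegisterOp, and NVVM_SpecialRegisterOp becomes a non-pure base class used by these ops. This prevents the CSE pass from merging repeated reads of these registers into a single read.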
1 parent 6a99326 commit ddab115

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,23 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
153153
// NVVM special register op definitions
154154
//===----------------------------------------------------------------------===//
155155

156-
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
156+
// NVVM_PureSpecialRegisterOp represents special register ops that can be
157+
// speculated and do not touch memory. These operations are always
158+
// legal to hoist or sink.
159+
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
157160
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
158161
let arguments = (ins);
159162
let assemblyFormat = "attr-dict `:` type($res)";
160163
}
161164

165+
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
166+
NVVM_IntrOp<mnemonic, traits, 1> {
167+
let arguments = (ins);
168+
let assemblyFormat = "attr-dict `:` type($res)";
169+
}
170+
162171
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
163-
NVVM_SpecialRegisterOp<mnemonic,
172+
NVVM_PureSpecialRegisterOp<mnemonic,
164173
!listconcat(traits,
165174
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
166175
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
@@ -199,11 +208,11 @@ def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
199208

200209
//===----------------------------------------------------------------------===//
201210
// Lane Mask Comparison Ops
202-
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
203-
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
204-
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
205-
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
206-
def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
211+
def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
212+
def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
213+
def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
214+
def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
215+
def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
207216

208217
//===----------------------------------------------------------------------===//
209218
// Thread index and range
@@ -256,7 +265,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
256265
//===----------------------------------------------------------------------===//
257266
// envreg registers
258267
foreach index = !range(0, 32) in {
259-
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
268+
def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
260269
}
261270

262271
//===----------------------------------------------------------------------===//
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// CHECK-LABEL: @nvvm_special_regs_clock
4+
llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> {
5+
%0 = llvm.mlir.zero: !llvm.struct<(i32, i32)>
6+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
7+
%1 = nvvm.read.ptx.sreg.clock : i32
8+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
9+
%2 = nvvm.read.ptx.sreg.clock : i32
10+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)>
11+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)>
12+
llvm.return %5: !llvm.struct<(i32, i32)>
13+
}
14+
15+
// CHECK-LABEL: @nvvm_special_regs_clock64
16+
llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> {
17+
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
18+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
19+
%1 = nvvm.read.ptx.sreg.clock64 : i64
20+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
21+
%2 = nvvm.read.ptx.sreg.clock64 : i64
22+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
23+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
24+
llvm.return %5: !llvm.struct<(i64, i64)>
25+
}
26+
27+
// CHECK-LABEL: @nvvm_special_regs_globaltimer
28+
llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> {
29+
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
30+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
31+
%1 = nvvm.read.ptx.sreg.globaltimer : i64
32+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
33+
%2 = nvvm.read.ptx.sreg.globaltimer : i64
34+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
35+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
36+
llvm.return %5: !llvm.struct<(i64, i64)>
37+
}

0 commit comments

Comments
 (0)