Skip to content

Commit ef6165d

Browse files
[NVVM][MLIR] Remove Pure trait from clock, clock64, globaltimer Ops
This commit removes Pure trait from clock, clock64 and globaltimer Ops by creating NVVM_NCSpecialRegisterOp class to represent Ops which return non-constant values. This prevents CSE pass from optimizing away redundant uses of them
1 parent 6a99326 commit ef6165d

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,21 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
153153
// NVVM special register op definitions
154154
//===----------------------------------------------------------------------===//
155155

156-
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
156+
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
157157
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
158158
let arguments = (ins);
159159
let assemblyFormat = "attr-dict `:` type($res)";
160160
}
161161

162+
// NVVM_SpecialRegisterOp represents a non-constant special register
163+
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
164+
NVVM_IntrOp<mnemonic, traits, 1> {
165+
let arguments = (ins);
166+
let assemblyFormat = "attr-dict `:` type($res)";
167+
}
168+
162169
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
163-
NVVM_SpecialRegisterOp<mnemonic,
170+
NVVM_PureSpecialRegisterOp<mnemonic,
164171
!listconcat(traits,
165172
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
166173
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
@@ -199,11 +206,11 @@ def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
199206

200207
//===----------------------------------------------------------------------===//
201208
// Lane Mask Comparison Ops
202-
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
203-
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
204-
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
205-
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
206-
def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
209+
def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
210+
def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
211+
def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
212+
def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
213+
def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
207214

208215
//===----------------------------------------------------------------------===//
209216
// Thread index and range
@@ -256,7 +263,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
256263
//===----------------------------------------------------------------------===//
257264
// envreg registers
258265
foreach index = !range(0, 32) in {
259-
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
266+
def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
260267
}
261268

262269
//===----------------------------------------------------------------------===//
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// CHECK-LABEL: @nvvm_special_regs_clock
4+
llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> {
5+
%0 = llvm.mlir.zero: !llvm.struct<(i32, i32)>
6+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
7+
%1 = nvvm.read.ptx.sreg.clock : i32
8+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
9+
%2 = nvvm.read.ptx.sreg.clock : i32
10+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)>
11+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)>
12+
llvm.return %5: !llvm.struct<(i32, i32)>
13+
}
14+
15+
// CHECK-LABEL: @nvvm_special_regs_clock64
16+
llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> {
17+
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
18+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
19+
%1 = nvvm.read.ptx.sreg.clock64 : i64
20+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
21+
%2 = nvvm.read.ptx.sreg.clock64 : i64
22+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
23+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
24+
llvm.return %5: !llvm.struct<(i64, i64)>
25+
}
26+
27+
// CHECK-LABEL: @nvvm_special_regs_globaltimer
28+
llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> {
29+
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
30+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
31+
%1 = nvvm.read.ptx.sreg.globaltimer : i64
32+
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
33+
%2 = nvvm.read.ptx.sreg.globaltimer : i64
34+
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
35+
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
36+
llvm.return %5: !llvm.struct<(i64, i64)>
37+
}

0 commit comments

Comments
 (0)