Skip to content

[NVVM][MLIR] Remove Pure trait from clock, clock64, globaltimer Ops #147608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,23 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
// NVVM special register op definitions
//===----------------------------------------------------------------------===//

class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
// NVVM_PureSpecialRegisterOp represents special register ops that can
// speculated and does not touch memory. These operations are always
// legal to hoist or sink.
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
let arguments = (ins);
let assemblyFormat = "attr-dict `:` type($res)";
}

class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, traits, 1> {
let arguments = (ins);
let assemblyFormat = "attr-dict `:` type($res)";
}

class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_SpecialRegisterOp<mnemonic,
NVVM_PureSpecialRegisterOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
Expand Down Expand Up @@ -199,11 +208,11 @@ def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;

//===----------------------------------------------------------------------===//
// Lane Mask Comparison Ops
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;

//===----------------------------------------------------------------------===//
// Thread index and range
Expand Down Expand Up @@ -256,7 +265,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
//===----------------------------------------------------------------------===//
// envreg registers
foreach index = !range(0, 32) in {
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
}

//===----------------------------------------------------------------------===//
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Dialect/LLVMIR/cse-nvvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @nvvm_special_regs_clock
llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> {
%0 = llvm.mlir.zero: !llvm.struct<(i32, i32)>
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
%1 = nvvm.read.ptx.sreg.clock : i32
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
%2 = nvvm.read.ptx.sreg.clock : i32
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)>
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)>
llvm.return %5: !llvm.struct<(i32, i32)>
}

// CHECK-LABEL: @nvvm_special_regs_clock64
llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> {
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
%1 = nvvm.read.ptx.sreg.clock64 : i64
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
%2 = nvvm.read.ptx.sreg.clock64 : i64
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
llvm.return %5: !llvm.struct<(i64, i64)>
}

// CHECK-LABEL: @nvvm_special_regs_globaltimer
llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> {
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
%1 = nvvm.read.ptx.sreg.globaltimer : i64
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
%2 = nvvm.read.ptx.sreg.globaltimer : i64
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
llvm.return %5: !llvm.struct<(i64, i64)>
}