From 9c0d0f678d060d4de8e234a90ec59fb48666dda0 Mon Sep 17 00:00:00 2001 From: pradeepku Date: Mon, 7 Jul 2025 14:05:00 +0530 Subject: [PATCH] [NVVM][MLIR] Remove Pure trait from clock, clock64, globaltimer Ops This commit removes the Pure trait from the clock, clock64 and globaltimer Ops, which return non-constant values. The existing Pure base classes are renamed to NVVM_PureSpecialRegisterOp and NVVM_PureSpecialRangeableRegisterOp, and NVVM_SpecialRegisterOp, which these three Ops keep using, is redefined without the Pure trait. This prevents the CSE pass from merging repeated reads of these registers into a single read. --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 90 +++++++++++---------- mlir/test/Dialect/LLVMIR/cse-nvvm.mlir | 37 +++++++++ 2 files changed, 85 insertions(+), 42 deletions(-) create mode 100644 mlir/test/Dialect/LLVMIR/cse-nvvm.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6895e946b8a45..45a8904375e2b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -153,14 +153,20 @@ class NVVM_IntrOp traits = [], // NVVM special register op definitions //===----------------------------------------------------------------------===// -class NVVM_SpecialRegisterOp traits = []> : +class NVVM_PureSpecialRegisterOp traits = []> : NVVM_IntrOp { let arguments = (ins); let assemblyFormat = "attr-dict `:` type($res)"; } -class NVVM_SpecialRangeableRegisterOp traits = []> : - NVVM_SpecialRegisterOp traits = []> : + NVVM_IntrOp { + let arguments = (ins); + let assemblyFormat = "attr-dict `:` type($res)"; +} + +class NVVM_PureSpecialRangeableRegisterOp traits = []> : + NVVM_PureSpecialRegisterOp])> { let arguments = (ins OptionalAttr:$range); @@ -189,63 +195,63 @@ class NVVM_SpecialRangeableRegisterOp traits = []> //===----------------------------------------------------------------------===// // Lane, Warp, SM, Grid index and range -def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">; -def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">; -def NVVM_WarpIdOp : 
NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">; -def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">; -def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">; -def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">; -def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">; +def NVVM_LaneIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.laneid">; +def NVVM_WarpSizeOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">; +def NVVM_WarpIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpid">; +def NVVM_WarpDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">; +def NVVM_SmIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.smid">; +def NVVM_SmDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">; +def NVVM_GridIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.gridid">; //===----------------------------------------------------------------------===// // Lane Mask Comparison Ops -def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">; -def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">; -def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">; -def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">; -def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">; +def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">; +def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">; +def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">; +def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">; +def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">; //===----------------------------------------------------------------------===// // Thread index and range -def NVVM_ThreadIdXOp : 
NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">; -def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">; -def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">; -def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">; -def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">; -def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">; +def NVVM_ThreadIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">; +def NVVM_ThreadIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">; +def NVVM_ThreadIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">; +def NVVM_BlockDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">; +def NVVM_BlockDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">; +def NVVM_BlockDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">; //===----------------------------------------------------------------------===// // Block index and range -def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">; -def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">; -def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">; -def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">; -def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">; -def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">; +def NVVM_BlockIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">; +def NVVM_BlockIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">; +def NVVM_BlockIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">; +def NVVM_GridDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">; +def NVVM_GridDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">; +def 
NVVM_GridDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">; //===----------------------------------------------------------------------===// // CTA Cluster index and range -def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>; -def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">; -def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">; -def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">; -def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">; -def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">; +def NVVM_ClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>; +def NVVM_ClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">; +def NVVM_ClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">; +def NVVM_ClusterDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">; +def NVVM_ClusterDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">; +def NVVM_ClusterDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">; //===----------------------------------------------------------------------===// // CTA index and range within Cluster -def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>; -def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>; -def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>; -def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>; -def NVVM_ClusterDimBlocksYOp : 
NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>; -def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">; +def NVVM_BlockInClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>; +def NVVM_BlockInClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>; +def NVVM_BlockInClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>; +def NVVM_ClusterDimBlocksXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>; +def NVVM_ClusterDimBlocksYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>; +def NVVM_ClusterDimBlocksZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">; //===----------------------------------------------------------------------===// // CTA index and across Cluster dimensions -def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>; -def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">; +def NVVM_ClusterId : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>; +def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">; //===----------------------------------------------------------------------===// // Clock registers @@ -256,7 +262,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">; //===----------------------------------------------------------------------===// // envreg registers foreach index = !range(0, 32) in { - def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>; + def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>; } 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/cse-nvvm.mlir b/mlir/test/Dialect/LLVMIR/cse-nvvm.mlir new file mode 100644 index 0000000000000..8d24c3846f178 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/cse-nvvm.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @nvvm_special_regs_clock +llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> { + %0 = llvm.mlir.zero: !llvm.struct<(i32, i32)> + // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock + %1 = nvvm.read.ptx.sreg.clock : i32 + // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock + %2 = nvvm.read.ptx.sreg.clock : i32 + %4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)> + %5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)> + llvm.return %5: !llvm.struct<(i32, i32)> +} + +// CHECK-LABEL: @nvvm_special_regs_clock64 +llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> { + %0 = llvm.mlir.zero: !llvm.struct<(i64, i64)> + // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64 + %1 = nvvm.read.ptx.sreg.clock64 : i64 + // CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64 + %2 = nvvm.read.ptx.sreg.clock64 : i64 + %4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)> + %5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)> + llvm.return %5: !llvm.struct<(i64, i64)> +} + +// CHECK-LABEL: @nvvm_special_regs_globaltimer +llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> { + %0 = llvm.mlir.zero: !llvm.struct<(i64, i64)> + // CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer + %1 = nvvm.read.ptx.sreg.globaltimer : i64 + // CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer + %2 = nvvm.read.ptx.sreg.globaltimer : i64 + %4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)> + %5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)> + llvm.return %5: !llvm.struct<(i64, i64)> +}