Skip to content

Commit 66e41a1

Browse files
authored
[MLIR][NVVM] Declare InferIntRangeInterface for RangeableRegisterOp (llvm#122263)
1 parent 98e5962 commit 66e41a1

File tree

4 files changed

+61
-2
lines changed

4 files changed

+61
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2020
#include "mlir/IR/Dialect.h"
2121
#include "mlir/IR/OpDefinition.h"
22+
#include "mlir/Interfaces/InferIntRangeInterface.h"
2223
#include "mlir/Interfaces/SideEffectInterfaces.h"
2324
#include "llvm/IR/IntrinsicsNVPTX.h"
2425

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1818
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1919
include "mlir/Interfaces/SideEffectInterfaces.td"
2020
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
21+
include "mlir/Interfaces/InferIntRangeInterface.td"
2122

2223
// LLVM pointer type constrained to address space 0 (the "generic" space, per
// the def name).
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2324
// LLVM pointer type constrained to address space 1 (the "global" space, per
// the def name).
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -134,8 +135,8 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
134135
let assemblyFormat = "attr-dict `:` type($res)";
135136
}
136137

137-
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
138-
NVVM_SpecialRegisterOp<mnemonic, traits> {
138+
class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
139+
NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
139140
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
140141
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
141142
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -147,6 +148,17 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
147148
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
148149
}]>
149150
];
151+
152+
// Define this method for the InferIntRangeInterface.
153+
let extraClassDefinition = [{
154+
// InferIntRangeInterface hook emitted for every op derived from this class
// (`$cppClass` is the ODS placeholder for the generated op class name).
// It forwards the op, its result, the incoming argument ranges and the
// callback to the file-local helper `nvvmInferResultRanges`, which derives
// the result range from the optional `range` attribute.
155+
void $cppClass::inferResultRanges(
156+
ArrayRef<::mlir::ConstantIntRanges> argRanges,
157+
SetIntRangeFn setResultRanges) {
158+
nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
159+
}
160+
}];
161+
150162
}
151163

152164
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,17 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
11581158
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
11591159
}
11601160

1161+
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
1162+
/// have ConstantRangeAttr.
1163+
static void nvvmInferResultRanges(Operation *op, Value result,
1164+
ArrayRef<::mlir::ConstantIntRanges> argRanges,
1165+
SetIntRangeFn setResultRanges) {
1166+
if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
1167+
setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
1168+
rangeAttr.getLower(), rangeAttr.getUpper()});
1169+
}
1170+
}
1171+
11611172
//===----------------------------------------------------------------------===//
11621173
// NVVMDialect initialization, type parsing, and registration.
11631174
//===----------------------------------------------------------------------===//
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
// Exercises integer range inference on NVVM special-register reads that carry
// a `range` attribute: comparisons whose outcome is decided by the declared
// range are expected to be replaced by a constant, the others must survive.
2+
gpu.module @module{
3+
gpu.func @kernel_1() kernel {
4+
%tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
5+
%tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
6+
%tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32
7+
%c64 = arith.constant 64 : i32
8+
9+
// tid.x is declared with range <i32, 0, 32>, so `tid.x > 64` is expected to
// fold to false.
%1 = arith.cmpi sgt, %tidx, %c64 : i32
10+
scf.if %1 {
11+
gpu.printf "threadidx"
12+
}
13+
// tid.y is declared with range <i32, 0, 128>, which contains values above 64,
// so this comparison is expected to remain in the output.
%2 = arith.cmpi sgt, %tidy, %c64 : i32
14+
scf.if %2 {
15+
gpu.printf "threadidy"
16+
}
17+
// tid.z is declared with range <i32, 0, 4>, so `tid.z > 64` is expected to
// fold to false.
%3 = arith.cmpi sgt, %tidz, %c64 : i32
18+
scf.if %3 {
19+
gpu.printf "threadidz"
20+
}
21+
gpu.return
22+
}
23+
}
24+
25+
// CHECK-LABEL: gpu.func @kernel_1
26+
// CHECK: %[[false:.+]] = arith.constant false
27+
// CHECK: %[[c64_i32:.+]] = arith.constant 64 : i32
28+
// CHECK: %[[S0:.+]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
29+
// CHECK: scf.if %[[false]] {
30+
// CHECK: gpu.printf "threadidx"
31+
// CHECK: %[[S1:.+]] = arith.cmpi sgt, %[[S0]], %[[c64_i32]] : i32
32+
// CHECK: scf.if %[[S1]] {
33+
// CHECK: gpu.printf "threadidy"
34+
// CHECK: scf.if %[[false]] {
35+
// CHECK: gpu.printf "threadidz"

0 commit comments

Comments
 (0)