From f6d951b96f35ae2bb0447ca791b4c5f01f1b975b Mon Sep 17 00:00:00 2001 From: Modi Mo Date: Thu, 26 Jun 2025 23:25:05 -0700 Subject: [PATCH 1/3] update mapa --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 25 ++++++++++++++++++--- mlir/test/Dialect/LLVMIR/invalid.mlir | 4 ++-- mlir/test/Dialect/LLVMIR/nvvm.mlir | 2 +- mlir/test/Target/LLVMIR/nvvmir.mlir | 4 ++-- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6895e946b8a45..e55060dc04204 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3009,10 +3009,29 @@ def NVVM_GriddepcontrolLaunchDependentsOp // NVVM Mapa Op //===----------------------------------------------------------------------===// +// Helper predicates for address space checking +def IsGenericAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 0">; +def IsSharedAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 3">; +def IsSharedClusterAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 7">; + +class NVVM_AddressSpaceMapping : + PredOpTrait<"valid address space mapping for NVVM mapa operation", + Or<[ + // Generic -> Generic + And<[ + SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>, + SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace> + ]>, + // Shared -> SharedCluster + And<[ + SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>, + SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace> + ]> + ]>>; + def NVVM_MapaOp: NVVM_Op<"mapa", - [TypesMatchWith<"`res` and `a` should have the same type", - "a", "res", "$_self">, NVVMRequiresSM<90>]> { - let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res); + [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> { + let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res); let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); string llvmBuilder = [{ diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 251ca716c7a7a..7a85eea58c558 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { // ----- func.func @mapa(%a: !llvm.ptr, %b : i32) { - // expected-error @below {{`res` and `a` should have the same type}} - %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3> + // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}} + %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7> return } diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c7fa41c98ac92..4349193aa1a45 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { // CHECK: nvvm.mapa %{{.*}} %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr // CHECK: nvvm.mapa %{{.*}} - %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7> return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..c119c1a0fd21f 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -760,8 +760,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() { llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { // CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}}) %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr - // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7> llvm.return } From 65552af49a860190a45ac4f742d50300839d6312 Mon Sep 17 00:00:00 2001 From: Modi Mo Date: Tue, 1 Jul 2025 15:27:23 -0700 Subject: [PATCH 2/3] review feedback --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 6 +++--- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e55060dc04204..431a33412c43e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3014,8 +3014,8 @@ def IsGenericAddressSpace : CPred<"llvm::cast($_self).get def IsSharedAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 3">; def IsSharedClusterAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 7">; -class NVVM_AddressSpaceMapping : - PredOpTrait<"valid address space mapping for NVVM mapa operation", +class NVVM_MapaASCheck : + PredOpTrait<"Valid address-space check(or mapping) for mapa Op", Or<[ // Generic -> Generic And<[ @@ -3030,7 +3030,7 @@ class NVVM_AddressSpaceMapping : ]>>; def NVVM_MapaOp: NVVM_Op<"mapa", - [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> { + [NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> { let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res); let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 7a85eea58c558..2c1c3071c456e 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1201,7 +1201,7 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { // ----- func.func @mapa(%a: !llvm.ptr, %b : i32) { - // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}} + // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}} %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7> return } From 088fddf0b88ee4113fb0fc848bb5e07e08cc29a5 Mon Sep 17 00:00:00 2001 From: Modi Mo Date: Thu, 3 Jul 2025 20:37:44 -0700 Subject: [PATCH 3/3] make generic helper and update usage to it --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 27 ++++----------------- mlir/include/mlir/IR/OpBase.td | 16 ++++++++++++ 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 431a33412c43e..9ebaac6f0fd80 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3009,28 +3009,11 @@ def NVVM_GriddepcontrolLaunchDependentsOp // NVVM Mapa Op //===----------------------------------------------------------------------===// -// Helper predicates for address space checking -def IsGenericAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 0">; -def IsSharedAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 3">; -def IsSharedClusterAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 7">; - -class NVVM_MapaASCheck : - PredOpTrait<"Valid address-space check(or mapping) for mapa Op", - Or<[ - // Generic -> Generic - And<[ - SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>, - SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace> - ]>, - // Shared -> SharedCluster - And<[ - SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>, - SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace> - ]> - ]>>; - -def NVVM_MapaOp: NVVM_Op<"mapa", - [NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> { +def NVVM_MapaASCheck : PredOpTrait<"Valid address-space check(or mapping) for mapa Op", + Or<[InputMatchesTypes<["a", "res"], [LLVM_PointerShared, LLVM_PointerSharedCluster]>.predicate, + InputMatchesTypes<["a", "res"], [LLVM_PointerGeneric, LLVM_PointerGeneric]>.predicate]>>; + +def NVVM_MapaOp: NVVM_Op<"mapa", [NVVM_MapaASCheck, NVVMRequiresSM<90>]> { let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res); let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 43ef28624fb19..b21603a410c0c 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -603,6 +603,22 @@ class RangedTypesMatchWith : TypesMatchWith; +// Checks that each inputArg has the same type as the corresponding entry +// in allowedTypes +class InputMatchesTypes inputArgs, list allowedTypes> : + PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types", + !foldl(TruePred, !range(!size(inputArgs)), acc, i, + And<[acc, + SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()", + allowedTypes[i].predicate> + ]>)> { + assert !eq(!size(inputArgs), !size(allowedTypes)), + "inputArgs and allowedTypes lists must have the same length"; + + list inputArgList = inputArgs; + list allowedTypeList = allowedTypes; +} + // Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : And<[ CPred<"$_op.getNumOperands() > " # idx>,