Skip to content

Commit 9cb9f2c

Browse files
committed
update mapa
1 parent 786ccb2 commit 9cb9f2c

File tree

4 files changed

+27
-8
lines changed

4 files changed

+27
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,10 +3009,29 @@ def NVVM_GriddepcontrolLaunchDependentsOp
30093009
// NVVM Mapa Op
30103010
//===----------------------------------------------------------------------===//
30113011

3012+
// Helper predicates for address space checking
3013+
def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
3014+
def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
3015+
def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
3016+
3017+
class NVVM_AddressSpaceMapping<string inputArg, string resultArg> :
3018+
PredOpTrait<"valid address space mapping for NVVM mapa operation",
3019+
Or<[
3020+
// Generic -> Generic
3021+
And<[
3022+
SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
3023+
SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
3024+
]>,
3025+
// Shared -> SharedCluster
3026+
And<[
3027+
SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
3028+
SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
3029+
]>
3030+
]>>;
3031+
30123032
def NVVM_MapaOp: NVVM_Op<"mapa",
3013-
[TypesMatchWith<"`res` and `a` should have the same type",
3014-
"a", "res", "$_self">, NVVMRequiresSM<90>]> {
3015-
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
3033+
[NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> {
3034+
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
30163035
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
30173036

30183037
string llvmBuilder = [{

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
12011201
// -----
12021202

12031203
func.func @mapa(%a: !llvm.ptr, %b : i32) {
1204-
// expected-error @below {{`res` and `a` should have the same type}}
1205-
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
1204+
// expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}}
1205+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
12061206
return
12071207
}
12081208

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
552552
// CHECK: nvvm.mapa %{{.*}}
553553
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
554554
// CHECK: nvvm.mapa %{{.*}}
555-
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
555+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
556556
return
557557
}
558558

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
760760
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
761761
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
762762
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
763-
// CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
764-
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
763+
// CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
764+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
765765
llvm.return
766766
}
767767

0 commit comments

Comments
 (0)