Skip to content

Commit d66f59a

Browse files
committed
update mapa
1 parent 8dbf92e commit d66f59a

File tree

4 files changed

+27
-8
lines changed

4 files changed

+27
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2600,10 +2600,29 @@ def NVVM_GriddepcontrolLaunchDependentsOp
26002600
// NVVM Mapa Op
26012601
//===----------------------------------------------------------------------===//
26022602

2603+
// Helper predicates for address space checking
2604+
def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
2605+
def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
2606+
def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
2607+
2608+
class NVVM_AddressSpaceMapping<string inputArg, string resultArg> :
2609+
PredOpTrait<"valid address space mapping for NVVM mapa operation",
2610+
Or<[
2611+
// Generic -> Generic
2612+
And<[
2613+
SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
2614+
SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
2615+
]>,
2616+
// Shared -> SharedCluster
2617+
And<[
2618+
SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
2619+
SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
2620+
]>
2621+
]>>;
2622+
26032623
def NVVM_MapaOp: NVVM_Op<"mapa",
2604-
[TypesMatchWith<"`res` and `a` should have the same type",
2605-
"a", "res", "$_self">]> {
2606-
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
2624+
[NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> {
2625+
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
26072626
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
26082627

26092628
string llvmBuilder = [{

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
12011201
// -----
12021202

12031203
func.func @mapa(%a: !llvm.ptr, %b : i32) {
1204-
// expected-error @below {{`res` and `a` should have the same type}}
1205-
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
1204+
// expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}}
1205+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
12061206
return
12071207
}
12081208

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
552552
// CHECK: nvvm.mapa %{{.*}}
553553
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
554554
// CHECK: nvvm.mapa %{{.*}}
555-
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
555+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
556556
return
557557
}
558558

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,8 +769,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
769769
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
770770
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
771771
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
772-
// CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
773-
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
772+
// CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
773+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
774774
llvm.return
775775
}
776776

0 commit comments

Comments
 (0)