Skip to content

Commit 2abf74d

Browse files
committed
Match pass behavior with other attach target passes. Split check line. Update comment.
1 parent 36edd3f commit 2abf74d

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
258258
];
259259
}
260260

261-
def GpuXeVMAttachTarget : Pass<"xevm-attach-target", "mlir::gpu::GPUModuleOp"> {
261+
def GpuXeVMAttachTarget : Pass<"xevm-attach-target", ""> {
262262
let summary = "Attaches a XeVM target attribute to a GPU Module.";
263263
let description = [{
264264
This pass searches for all GPU Modules in the immediate regions and attaches

mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- XeVMAttachTarget.cpp - DESC -----------------------------*- C++ -*-===//
1+
//===-- XeVMAttachTarget.cpp - Attach an XeVM target ----------------------===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -68,21 +68,24 @@ void XeVMAttachTarget::runOnOperation() {
6868
OpBuilder builder(&getContext());
6969
ArrayRef<std::string> libs(linkLibs);
7070
SmallVector<StringRef> filesToLink(libs);
71-
auto target = builder.getAttr<mlir::xevm::XeVMTargetAttr>(
71+
auto target = builder.getAttr<xevm::XeVMTargetAttr>(
7272
optLevel, triple, chip, getFlags(builder),
7373
filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink));
7474
llvm::Regex matcher(moduleMatcher);
75-
// Check if the name of the module matches.
76-
auto gpuModule = cast<gpu::GPUModuleOp>(getOperation());
77-
if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
78-
return;
79-
// Create the target array.
80-
SmallVector<Attribute> targets;
81-
if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
82-
targets.append(attrs->getValue().begin(), attrs->getValue().end());
83-
targets.push_back(target);
84-
// Remove any duplicate targets.
85-
targets.erase(llvm::unique(targets), targets.end());
86-
// Update the target attribute array.
87-
gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
75+
for (Region &region : getOperation()->getRegions())
76+
for (Block &block : region.getBlocks())
77+
for (auto module : block.getOps<gpu::GPUModuleOp>()) {
78+
// Check if the name of the module matches.
79+
if (!moduleMatcher.empty() && !matcher.match(module.getName()))
80+
continue;
81+
// Create the target array.
82+
SmallVector<Attribute> targets;
83+
if (std::optional<ArrayAttr> attrs = module.getTargets())
84+
targets.append(attrs->getValue().begin(), attrs->getValue().end());
85+
targets.push_back(target);
86+
// Remove any duplicate targets.
87+
targets.erase(llvm::unique(targets), targets.end());
88+
// Update the target attribute array.
89+
module.setTargetsAttr(builder.getArrayAttr(targets));
90+
}
8891
}

mlir/test/Dialect/LLVMIR/attach-targets.mlir

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@ gpu.module @rocdl_module {
2323
gpu.module @xevm_module {
2424
}
2525
// Check the options were added.
26-
// CHECK_OPTS: @options_module_1 [#nvvm.target<O = 1, chip = "sm_70", flags = {fast, ftz}>, #rocdl.target<flags = {finite_only, no_wave64}, link = ["file1.bc", "file2.bc"]>, #xevm.target<O = 1, chip = "pvc">] {
26+
// CHECK_OPTS: @options_module_1 [#nvvm.target<O = 1, chip = "sm_70", flags = {fast, ftz}>,
27+
// CHECK_OPTS-SAME: #rocdl.target<flags = {finite_only, no_wave64}, link = ["file1.bc", "file2.bc"]>,
28+
// CHECK_OPTS-SAME: #xevm.target<O = 1, chip = "pvc">] {
2729
gpu.module @options_module_1 {
2830
}
2931
// Check the options were added and that the first target was preserved.
30-
// CHECK_OPTS: @options_module_2 [#nvvm.target<O = 3, chip = "sm_90">, #nvvm.target<O = 1, chip = "sm_70", flags = {fast, ftz}>, #rocdl.target<flags = {finite_only, no_wave64}, link = ["file1.bc", "file2.bc"]>, #xevm.target<O = 1, chip = "pvc">] {
32+
// CHECK_OPTS: @options_module_2 [#nvvm.target<O = 3, chip = "sm_90">,
33+
// CHECK_OPTS-SAME: #nvvm.target<O = 1, chip = "sm_70", flags = {fast, ftz}>,
34+
// CHECK_OPTS-SAME: #rocdl.target<flags = {finite_only, no_wave64}, link = ["file1.bc", "file2.bc"]>,
35+
// CHECK_OPTS-SAME: #xevm.target<O = 1, chip = "pvc">] {
3136
gpu.module @options_module_2 [#nvvm.target<O = 3, chip = "sm_90">] {
3237
}
3338
}

0 commit comments

Comments
 (0)