|
1 |
| -//===-- XeVMAttachTarget.cpp - DESC -----------------------------*- C++ -*-===// |
| 1 | +//===-- XeVMAttachTarget.cpp - Attach an XeVM target ----------------------===// |
2 | 2 | //
|
3 | 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
4 | 4 | // See https://llvm.org/LICENSE.txt for license information.
|
@@ -68,21 +68,24 @@ void XeVMAttachTarget::runOnOperation() {
|
68 | 68 | OpBuilder builder(&getContext());
|
69 | 69 | ArrayRef<std::string> libs(linkLibs);
|
70 | 70 | SmallVector<StringRef> filesToLink(libs);
|
71 |
| - auto target = builder.getAttr<mlir::xevm::XeVMTargetAttr>( |
| 71 | + auto target = builder.getAttr<xevm::XeVMTargetAttr>( |
72 | 72 | optLevel, triple, chip, getFlags(builder),
|
73 | 73 | filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink));
|
74 | 74 | llvm::Regex matcher(moduleMatcher);
|
75 |
| - // Check if the name of the module matches. |
76 |
| - auto gpuModule = cast<gpu::GPUModuleOp>(getOperation()); |
77 |
| - if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName())) |
78 |
| - return; |
79 |
| - // Create the target array. |
80 |
| - SmallVector<Attribute> targets; |
81 |
| - if (std::optional<ArrayAttr> attrs = gpuModule.getTargets()) |
82 |
| - targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
83 |
| - targets.push_back(target); |
84 |
| - // Remove any duplicate targets. |
85 |
| - targets.erase(llvm::unique(targets), targets.end()); |
86 |
| - // Update the target attribute array. |
87 |
| - gpuModule.setTargetsAttr(builder.getArrayAttr(targets)); |
| 75 | + for (Region ®ion : getOperation()->getRegions()) |
| 76 | + for (Block &block : region.getBlocks()) |
| 77 | + for (auto module : block.getOps<gpu::GPUModuleOp>()) { |
| 78 | + // Check if the name of the module matches. |
| 79 | + if (!moduleMatcher.empty() && !matcher.match(module.getName())) |
| 80 | + continue; |
| 81 | + // Create the target array. |
| 82 | + SmallVector<Attribute> targets; |
| 83 | + if (std::optional<ArrayAttr> attrs = module.getTargets()) |
| 84 | + targets.append(attrs->getValue().begin(), attrs->getValue().end()); |
| 85 | + targets.push_back(target); |
| 86 | + // Remove any duplicate targets. |
| 87 | + targets.erase(llvm::unique(targets), targets.end()); |
| 88 | + // Update the target attribute array. |
| 89 | + module.setTargetsAttr(builder.getArrayAttr(targets)); |
| 90 | + } |
88 | 91 | }
|
0 commit comments