@@ -1725,15 +1725,35 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
1725
1725
}
1726
1726
};
1727
1727
1728
- static bool isDeviceAllocation (mlir::Value val) {
1728
+ static bool isDeviceAllocation (mlir::Value val, mlir::Value adaptorVal ) {
1729
1729
if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp ()))
1730
- return isDeviceAllocation (loadOp.getMemref ());
1730
+ return isDeviceAllocation (loadOp.getMemref (), {} );
1731
1731
if (auto boxAddrOp =
1732
1732
mlir::dyn_cast_or_null<fir::BoxAddrOp>(val.getDefiningOp ()))
1733
- return isDeviceAllocation (boxAddrOp.getVal ());
1733
+ return isDeviceAllocation (boxAddrOp.getVal (), {} );
1734
1734
if (auto convertOp =
1735
1735
mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp ()))
1736
- return isDeviceAllocation (convertOp.getValue ());
1736
+ return isDeviceAllocation (convertOp.getValue (), {});
1737
+ if (!val.getDefiningOp () && adaptorVal) {
1738
+ if (auto blockArg = llvm::cast<mlir::BlockArgument>(adaptorVal)) {
1739
+ if (blockArg.getOwner () && blockArg.getOwner ()->getParentOp () &&
1740
+ blockArg.getOwner ()->isEntryBlock ()) {
1741
+ if (auto func = mlir::dyn_cast_or_null<mlir::FunctionOpInterface>(
1742
+ *blockArg.getOwner ()->getParentOp ())) {
1743
+ auto argAttrs = func.getArgAttrs (blockArg.getArgNumber ());
1744
+ for (auto attr : argAttrs) {
1745
+ if (attr.getName ().getValue ().ends_with (cuf::getDataAttrName ())) {
1746
+ auto dataAttr =
1747
+ mlir::dyn_cast<cuf::DataAttributeAttr>(attr.getValue ());
1748
+ if (dataAttr.getValue () != cuf::DataAttribute::Pinned &&
1749
+ dataAttr.getValue () != cuf::DataAttribute::Unified)
1750
+ return true ;
1751
+ }
1752
+ }
1753
+ }
1754
+ }
1755
+ }
1756
+ }
1737
1757
if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp ()))
1738
1758
if (callOp.getCallee () &&
1739
1759
(callOp.getCallee ().value ().getRootReference ().getValue ().starts_with (
@@ -1928,7 +1948,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
1928
1948
if (fir::isDerivedTypeWithLenParams (boxTy))
1929
1949
TODO (loc, " fir.embox codegen of derived with length parameters" );
1930
1950
mlir::Value result = placeInMemoryIfNotGlobalInit (
1931
- rewriter, loc, boxTy, dest, isDeviceAllocation (xbox.getMemref ()));
1951
+ rewriter, loc, boxTy, dest,
1952
+ isDeviceAllocation (xbox.getMemref (), adaptor.getMemref ()));
1932
1953
rewriter.replaceOp (xbox, result);
1933
1954
return mlir::success ();
1934
1955
}
@@ -2052,9 +2073,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
2052
2073
dest = insertStride (rewriter, loc, dest, dim, std::get<1 >(iter.value ()));
2053
2074
}
2054
2075
dest = insertBaseAddress (rewriter, loc, dest, base);
2055
- mlir::Value result =
2056
- placeInMemoryIfNotGlobalInit ( rewriter, rebox.getLoc (), destBoxTy, dest,
2057
- isDeviceAllocation (rebox.getBox ()));
2076
+ mlir::Value result = placeInMemoryIfNotGlobalInit (
2077
+ rewriter, rebox.getLoc (), destBoxTy, dest,
2078
+ isDeviceAllocation (rebox. getBox (), rebox.getBox ()));
2058
2079
rewriter.replaceOp (rebox, result);
2059
2080
return mlir::success ();
2060
2081
}
0 commit comments