@@ -2391,6 +2391,124 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2391
2391
queue, item, clauseOps);
2392
2392
}
2393
2393
2394
+ static mlir::FlatSymbolRefAttr
2395
+ genImplicitDefaultDeclareMapper (lower::AbstractConverter &converter,
2396
+ mlir::Location loc, fir::RecordType recordType,
2397
+ llvm::StringRef mapperNameStr) {
2398
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2399
+ lower::StatementContext stmtCtx;
2400
+
2401
+ // Save current insertion point before moving to the module scope to create
2402
+ // the DeclareMapperOp
2403
+ mlir::OpBuilder::InsertionGuard guard (firOpBuilder);
2404
+
2405
+ firOpBuilder.setInsertionPointToStart (converter.getModuleOp ().getBody ());
2406
+ auto declMapperOp = firOpBuilder.create <mlir::omp::DeclareMapperOp>(
2407
+ loc, mapperNameStr, recordType);
2408
+ auto ®ion = declMapperOp.getRegion ();
2409
+ firOpBuilder.createBlock (®ion);
2410
+ auto mapperArg = region.addArgument (firOpBuilder.getRefType (recordType), loc);
2411
+
2412
+ auto declareOp =
2413
+ firOpBuilder.create <hlfir::DeclareOp>(loc, mapperArg, /* uniq_name=*/ " " );
2414
+
2415
+ const auto genBoundsOps = [&](mlir::Value mapVal,
2416
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
2417
+ fir::ExtendedValue extVal =
2418
+ hlfir::translateToExtendedValue (mapVal.getLoc (), firOpBuilder,
2419
+ hlfir::Entity{mapVal},
2420
+ /* contiguousHint=*/ true )
2421
+ .first ;
2422
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr (
2423
+ firOpBuilder, mapVal, /* isOptional=*/ false , mapVal.getLoc ());
2424
+ bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
2425
+ mlir::omp::MapBoundsType>(
2426
+ firOpBuilder, info, extVal,
2427
+ /* dataExvIsAssumedSize=*/ false , mapVal.getLoc ());
2428
+ };
2429
+
2430
+ // Return a reference to the contents of a derived type with one field.
2431
+ // Also return the field type.
2432
+ const auto getFieldRef =
2433
+ [&](mlir::Value rec,
2434
+ unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
2435
+ auto recType = mlir::dyn_cast<fir::RecordType>(
2436
+ fir::unwrapPassByRefType (rec.getType ()));
2437
+ auto [fieldName, fieldTy] = recType.getTypeList ()[index];
2438
+ mlir::Value field = firOpBuilder.create <fir::FieldIndexOp>(
2439
+ loc, fir::FieldType::get (recType.getContext ()), fieldName, recType,
2440
+ fir::getTypeParams (rec));
2441
+ return {firOpBuilder.create <fir::CoordinateOp>(
2442
+ loc, firOpBuilder.getRefType (fieldTy), rec, field),
2443
+ fieldTy};
2444
+ };
2445
+
2446
+ mlir::omp::DeclareMapperInfoOperands clauseOps;
2447
+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberPlacementIndices;
2448
+ llvm::SmallVector<mlir::Value> memberMapOps;
2449
+
2450
+ llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2451
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
2452
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
2453
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
2454
+ mlir::omp::VariableCaptureKind captureKind =
2455
+ mlir::omp::VariableCaptureKind::ByRef;
2456
+ int64_t index = 0 ;
2457
+
2458
+ // Populate the declareMapper region with the map information.
2459
+ for (const auto &[memberName, memberType] :
2460
+ mlir::dyn_cast<fir::RecordType>(recordType).getTypeList ()) {
2461
+ auto [ref, type] = getFieldRef (declareOp.getBase (), index);
2462
+ mlir::FlatSymbolRefAttr mapperId;
2463
+ if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
2464
+ std::string mapperIdName =
2465
+ recType.getName ().str () + " .omp.default.mapper" ;
2466
+ if (auto *sym = converter.getCurrentScope ().FindSymbol (mapperIdName))
2467
+ mapperIdName = converter.mangleName (mapperIdName, sym->owner ());
2468
+ else if (auto *sym = converter.getCurrentScope ().FindSymbol (memberName))
2469
+ mapperIdName = converter.mangleName (mapperIdName, sym->owner ());
2470
+
2471
+ if (converter.getModuleOp ().lookupSymbol (mapperIdName))
2472
+ mapperId = mlir::FlatSymbolRefAttr::get (&converter.getMLIRContext (),
2473
+ mapperIdName);
2474
+ else
2475
+ mapperId = genImplicitDefaultDeclareMapper (converter, loc, recType,
2476
+ mapperIdName);
2477
+ }
2478
+
2479
+ llvm::SmallVector<mlir::Value> bounds;
2480
+ genBoundsOps (ref, bounds);
2481
+ mlir::Value mapOp = createMapInfoOp (
2482
+ firOpBuilder, loc, ref, /* varPtrPtr=*/ mlir::Value{}, " " , bounds,
2483
+ /* members=*/ {},
2484
+ /* membersIndex=*/ mlir::ArrayAttr{},
2485
+ static_cast <
2486
+ std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
2487
+ mapFlag),
2488
+ captureKind, ref.getType (), /* partialMap=*/ false , mapperId);
2489
+ memberMapOps.emplace_back (mapOp);
2490
+ memberPlacementIndices.emplace_back (llvm::SmallVector<int64_t >{index++});
2491
+ }
2492
+
2493
+ llvm::SmallVector<mlir::Value> bounds;
2494
+ genBoundsOps (declareOp.getOriginalBase (), bounds);
2495
+ mlir::omp::MapInfoOp mapOp = createMapInfoOp (
2496
+ firOpBuilder, loc, declareOp.getOriginalBase (),
2497
+ /* varPtrPtr=*/ mlir::Value (), /* name=*/ " " , bounds, memberMapOps,
2498
+ firOpBuilder.create2DI64ArrayAttr (memberPlacementIndices),
2499
+ static_cast <std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
2500
+ mapFlag),
2501
+ captureKind, declareOp.getType (0 ),
2502
+ /* partialMap=*/ true );
2503
+
2504
+ clauseOps.mapVars .emplace_back (mapOp);
2505
+
2506
+ firOpBuilder.create <mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars );
2507
+ // declMapperOp->dumpPretty();
2508
+ return mlir::FlatSymbolRefAttr::get (&converter.getMLIRContext (),
2509
+ mapperNameStr);
2510
+ }
2511
+
2394
2512
static mlir::omp::TargetOp
2395
2513
genTargetOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2396
2514
lower::StatementContext &stmtCtx,
@@ -2467,15 +2585,26 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2467
2585
name << sym.name ().ToString ();
2468
2586
2469
2587
mlir::FlatSymbolRefAttr mapperId;
2470
- if (sym.GetType ()->category () == semantics::DeclTypeSpec::TypeDerived) {
2588
+ if (sym.GetType ()->category () == semantics::DeclTypeSpec::TypeDerived &&
2589
+ defaultMaps.empty ()) {
2471
2590
auto &typeSpec = sym.GetType ()->derivedTypeSpec ();
2472
2591
std::string mapperIdName =
2473
2592
typeSpec.name ().ToString () + llvm::omp::OmpDefaultMapperName;
2474
2593
if (auto *sym = converter.getCurrentScope ().FindSymbol (mapperIdName))
2475
2594
mapperIdName = converter.mangleName (mapperIdName, sym->owner ());
2595
+ else
2596
+ mapperIdName =
2597
+ converter.mangleName (mapperIdName, *typeSpec.GetScope ());
2598
+
2476
2599
if (converter.getModuleOp ().lookupSymbol (mapperIdName))
2477
2600
mapperId = mlir::FlatSymbolRefAttr::get (&converter.getMLIRContext (),
2478
2601
mapperIdName);
2602
+ else
2603
+ mapperId = genImplicitDefaultDeclareMapper (
2604
+ converter, loc,
2605
+ mlir::cast<fir::RecordType>(
2606
+ converter.genType (sym.GetType ()->derivedTypeSpec ())),
2607
+ mapperIdName);
2479
2608
}
2480
2609
2481
2610
fir::factory::AddrAndBoundsInfo info =
@@ -3442,6 +3571,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3442
3571
ClauseProcessor cp (converter, semaCtx, clauses);
3443
3572
cp.processMap (loc, stmtCtx, clauseOps);
3444
3573
firOpBuilder.create <mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars );
3574
+ // declMapperOp->dumpPretty();
3445
3575
}
3446
3576
3447
3577
static void
0 commit comments