Commit 4868d66

[flang] improve DITypeAttr caching with recursive derived types (#146543)
The current DITypeAttr caching strategy for derived type debug metadata generation is not optimal. This turns out to be a compile-time issue in apps with very complex derived types, like CP2K. See the added debug-cyclic-derived-type-caching-simple.f90 test for more details about the duplication issue. As a real-world example justifying the new, non-trivial caching strategy: in CP2K, emitting debug type info for the `swarm_worker_type` in swarm_worker.F caused 1,747,347 LLVM debug metadata nodes to be emitted, versus 8023 after this patch (about 200x fewer), leading to noticeable compile-time improvements (I measured 0.12s spent in the `AddDebugInfo` pass instead of 7.5s prior to this patch).

The main idea is that the cache now associates, with the cached DITypeAttr tree of a derived type, the list of parent nodes that are referred to recursively (via integer ids) from inside this DITypeAttr. When leaving the context of a parent node, all types that were cached and linked to this parent node are cleared from the cache. This allows more reuse in sub-trees while still fulfilling the MLIR requirement that a DITypeAttr referring to a parent DITypeAttr via integer id may only be used inside the DITypeAttr of that parent.

Most of the complexity comes from computing the "list of parent nodes" by merging the lists coming from the components. This is done in such a way that the extra cost for apps without recursive derived types is minimal: the extra data structures should not require additional dynamic allocations when there is no or little recursion.

Example: take the following type graph (the Fortran source for it is in the added debug-cyclic-derived-type-caching-complex.f90). A is the top-level type and has direct components of types B, C, and E. There are cycles in the type tree, introduced by types B and D. Types `C` and `E` are of interest here because they sit in the middle of those cycles and appear in several places in the type tree. Their occurrences are labeled in brackets in the order they are visited by the DebugTypeGenerator.
```
A -> B -> C [1] -> D -> E [1] -> F -> G -> B
|    |             |             |
|    |             |             -> D
|    |             |
|    |             -> H -> E [2] -> F -> G -> B
|    |                              |
|    |                              -> D
|    |
|    -> I -> E [3] -> F -> G -> B
|    |                |
|    |                -> D
|    |
|    -> C [2]
|
-> C [3] -> D
|
-> E [4] -> F -> G -> B
            |
            -> D
```
With this patch, E[2] and E[3] can share the same DITypeAttr, as can C[1] and C[2], while previously they all got their own nodes. To be safe with regard to cycles in MLIR, a DITypeAttr created for a node N2 that sits under a recursively referred-to node N1 and above the recursive reference to N1 must not be used above N1 in the DITypeAttr tree. It can, however, be used in several places under N1. Hence, here:
- E[2] cannot reuse E[1]'s DITypeAttr because D appears both above and under E[1].
- E[3] can reuse E[2]'s DITypeAttr because they are both under B and above D.
- E[4] cannot reuse E[3]'s DITypeAttr because it is above B.

This is achieved as follows: when visiting A and reaching B, the recursive reference to B is registered in the visit context. D is added to this context when going back up in F. So when coming back up to E[1] with the information needed to build its DITypeAttr, its recursive references are known and are saved along with the DITypeAttr in the cache. When coming back up to D, the cache entry for E is cleared because it is known to depend on D. A new DITypeAttr is created for E[2], and this time it only depends on B, because the D under E[2] is not a recursive reference (D is not above E[2]). Hence, when reaching E[3], it can be reused, and the cache entry for E[2] is cleared when reaching back B, which leads to a new DITypeAttr being created for E[4].
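For illustration, a minimal Fortran sketch of this kind of cyclic derived-type graph could look as follows. The module and type names (`cyclic_types`, `a_t`, `b_t`, `c_t`, `d_t`) are hypothetical and this is not the actual source of the added tests; it only shows the pattern where a "middle" type of a cycle appears both under and above the recursively referenced type.

```fortran
! Hypothetical sketch of a cyclic derived-type graph (not the test source).
! The cycle is b_t -> c_t -> d_t -> b_t; c_t appears both under b_t (where
! its DITypeAttr can be shared) and directly under the top-level a_t.
module cyclic_types
  implicit none

  type :: c_t
    type(d_t), pointer :: d     ! forward reference, continues the cycle via d_t
  end type c_t

  type :: d_t
    type(b_t), pointer :: b     ! recursive reference back to b_t
  end type d_t

  type :: b_t
    type(c_t), pointer :: c1    ! first occurrence of c_t under b_t
    type(c_t), pointer :: c2    ! second occurrence under b_t, can reuse the cached node
  end type b_t

  type :: a_t                   ! top-level type
    type(b_t), pointer :: b
    type(c_t), pointer :: c     ! occurrence of c_t above b_t, needs a fresh node
  end type a_t
end module cyclic_types
```

In the terms of the graph above, the two `c_t` components of `b_t` play the role of C[1]/C[2], while the `c_t` component of `a_t` plays the role of C[3].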
1 parent b5f5a76 commit 4868d66

File tree

4 files changed: +354 -111 lines


flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp

Lines changed: 127 additions & 109 deletions
```diff
@@ -48,8 +48,7 @@ DebugTypeGenerator::DebugTypeGenerator(mlir::ModuleOp m,
                                        mlir::SymbolTable *symbolTable_,
                                        const mlir::DataLayout &dl)
     : module(m), symbolTable(symbolTable_), dataLayout{&dl},
-      kindMapping(getKindMapping(m)), llvmTypeConverter(m, false, false, dl),
-      derivedTypeDepth(0) {
+      kindMapping(getKindMapping(m)), llvmTypeConverter(m, false, false, dl) {
   LLVM_DEBUG(llvm::dbgs() << "DITypeAttr generator\n");
 
   mlir::MLIRContext *context = module.getContext();
@@ -272,31 +271,127 @@ DebugTypeGenerator::getFieldSizeAndAlign(mlir::Type fieldTy) {
   return std::pair{byteSize, byteAlign};
 }
 
+mlir::LLVM::DITypeAttr DerivedTypeCache::lookup(mlir::Type type) {
+  auto iter = typeCache.find(type);
+  if (iter != typeCache.end()) {
+    if (iter->second.first) {
+      componentActiveRecursionLevels = iter->second.second;
+    }
+    return iter->second.first;
+  }
+  return nullptr;
+}
+
+DerivedTypeCache::ActiveLevels
+DerivedTypeCache::startTranslating(mlir::Type type,
+                                   mlir::LLVM::DITypeAttr placeHolder) {
+  derivedTypeDepth++;
+  if (!placeHolder)
+    return {};
+  typeCache[type] = std::pair<mlir::LLVM::DITypeAttr, ActiveLevels>(
+      placeHolder, {derivedTypeDepth});
+  return {};
+}
+
+void DerivedTypeCache::preComponentVisitUpdate() {
+  componentActiveRecursionLevels.clear();
+}
+
+void DerivedTypeCache::postComponentVisitUpdate(
+    ActiveLevels &activeRecursionLevels) {
+  if (componentActiveRecursionLevels.empty())
+    return;
+  ActiveLevels oldLevels;
+  oldLevels.swap(activeRecursionLevels);
+  std::merge(componentActiveRecursionLevels.begin(),
+             componentActiveRecursionLevels.end(), oldLevels.begin(),
+             oldLevels.end(), std::back_inserter(activeRecursionLevels));
+}
+
+void DerivedTypeCache::finalize(mlir::Type ty, mlir::LLVM::DITypeAttr attr,
+                                ActiveLevels &&activeRecursionLevels) {
+  // If there is no nested recursion or if this type does not point to any type
+  // nodes above it, it is safe to cache it indefinitely (it can be used in any
+  // contexts).
+  if (activeRecursionLevels.empty() ||
+      (activeRecursionLevels[0] == derivedTypeDepth)) {
+    typeCache[ty] = std::pair<mlir::LLVM::DITypeAttr, ActiveLevels>(attr, {});
+    componentActiveRecursionLevels.clear();
+    cleanUpCache(derivedTypeDepth);
+    --derivedTypeDepth;
+    return;
+  }
+  // Trim any recursion below the current type.
+  if (activeRecursionLevels.back() >= derivedTypeDepth) {
+    auto last = llvm::find_if(activeRecursionLevels, [&](std::int32_t depth) {
+      return depth >= derivedTypeDepth;
+    });
+    if (last != activeRecursionLevels.end()) {
+      activeRecursionLevels.erase(last, activeRecursionLevels.end());
+    }
+  }
+  componentActiveRecursionLevels = std::move(activeRecursionLevels);
+  typeCache[ty] = std::pair<mlir::LLVM::DITypeAttr, ActiveLevels>(
+      attr, componentActiveRecursionLevels);
+  cleanUpCache(derivedTypeDepth);
+  if (!componentActiveRecursionLevels.empty())
+    insertCacheCleanUp(ty, componentActiveRecursionLevels.back());
+  --derivedTypeDepth;
+}
+
+void DerivedTypeCache::insertCacheCleanUp(mlir::Type type, int32_t depth) {
+  auto iter = llvm::find_if(cacheCleanupList,
+                            [&](const auto &x) { return x.second >= depth; });
+  if (iter == cacheCleanupList.end()) {
+    cacheCleanupList.emplace_back(
+        std::pair<llvm::SmallVector<mlir::Type>, int32_t>({type}, depth));
+    return;
+  }
+  if (iter->second == depth) {
+    iter->first.push_back(type);
+    return;
+  }
+  cacheCleanupList.insert(
+      iter, std::pair<llvm::SmallVector<mlir::Type>, int32_t>({type}, depth));
+}
+
+void DerivedTypeCache::cleanUpCache(int32_t depth) {
+  if (cacheCleanupList.empty())
+    return;
+  // cleanups are done in the post actions when visiting a derived type
+  // tree. So if there is a clean-up for the current depth, it has to be
+  // the last one (deeper ones must have been done already).
+  if (cacheCleanupList.back().second == depth) {
+    for (mlir::Type type : cacheCleanupList.back().first)
+      typeCache[type].first = nullptr;
+    cacheCleanupList.pop_back_n(1);
+  }
+}
+
 mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
     fir::RecordType Ty, mlir::LLVM::DIFileAttr fileAttr,
     mlir::LLVM::DIScopeAttr scope, fir::cg::XDeclareOp declOp) {
-  // Check if this type has already been converted.
-  auto iter = typeCache.find(Ty);
-  if (iter != typeCache.end())
-    return iter->second;
 
-  bool canCacheThisType = true;
-  llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
+  if (mlir::LLVM::DITypeAttr attr = derivedTypeCache.lookup(Ty))
+    return attr;
+
   mlir::MLIRContext *context = module.getContext();
-  auto recId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
+  auto [nameKind, sourceName] = fir::NameUniquer::deconstruct(Ty.getName());
+  if (nameKind != fir::NameUniquer::NameKind::DERIVED_TYPE)
+    return genPlaceholderType(context);
+
+  llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
   // Generate a place holder TypeAttr which will be used if a member
   // references the parent type.
-  auto comAttr = mlir::LLVM::DICompositeTypeAttr::get(
+  auto recId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
+  auto placeHolder = mlir::LLVM::DICompositeTypeAttr::get(
       context, recId, /*isRecSelf=*/true, llvm::dwarf::DW_TAG_structure_type,
       mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope,
       /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0,
       /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
       /*allocated=*/nullptr, /*associated=*/nullptr);
-  typeCache[Ty] = comAttr;
-
-  auto result = fir::NameUniquer::deconstruct(Ty.getName());
-  if (result.first != fir::NameUniquer::NameKind::DERIVED_TYPE)
-    return genPlaceholderType(context);
+  DerivedTypeCache::ActiveLevels nestedRecursions =
+      derivedTypeCache.startTranslating(Ty, placeHolder);
 
   fir::TypeInfoOp tiOp = symbolTable->lookup<fir::TypeInfoOp>(Ty.getName());
   unsigned line = (tiOp) ? getLineFromLoc(tiOp.getLoc()) : 1;
@@ -305,6 +400,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
   mlir::IntegerType intTy = mlir::IntegerType::get(context, 64);
   std::uint64_t offset = 0;
   for (auto [fieldName, fieldTy] : Ty.getTypeList()) {
+    derivedTypeCache.preComponentVisitUpdate();
     auto [byteSize, byteAlign] = getFieldSizeAndAlign(fieldTy);
     std::optional<llvm::ArrayRef<int64_t>> lowerBounds =
         fir::getComponentLowerBoundsIfNonDefault(Ty, fieldName, module,
@@ -317,22 +413,22 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
     mlir::LLVM::DITypeAttr elemTy;
     if (lowerBounds && seqTy &&
         lowerBounds->size() == seqTy.getShape().size()) {
-      llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
+      llvm::SmallVector<mlir::LLVM::DINodeAttr> arrayElements;
       for (auto [bound, dim] :
            llvm::zip_equal(*lowerBounds, seqTy.getShape())) {
         auto countAttr = mlir::IntegerAttr::get(intTy, llvm::APInt(64, dim));
         auto lowerAttr = mlir::IntegerAttr::get(intTy, llvm::APInt(64, bound));
         auto subrangeTy = mlir::LLVM::DISubrangeAttr::get(
             context, countAttr, lowerAttr, /*upperBound=*/nullptr,
             /*stride=*/nullptr);
-        elements.push_back(subrangeTy);
+        arrayElements.push_back(subrangeTy);
       }
       elemTy = mlir::LLVM::DICompositeTypeAttr::get(
           context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr,
           /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr,
           convertType(seqTy.getEleTy(), fileAttr, scope, declOp),
           mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0,
-          elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
+          arrayElements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
           /*allocated=*/nullptr, /*associated=*/nullptr);
     } else
       elemTy = convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr);
@@ -344,96 +440,37 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
         /*extra data=*/nullptr);
     elements.push_back(tyAttr);
     offset += llvm::alignTo(byteSize, byteAlign);
-
-    // Currently, the handling of recursive debug type in mlir has some
-    // limitations that were discussed at the end of the thread for following
-    // PR.
-    // https://github.com/llvm/llvm-project/pull/106571
-    //
-    // Problem could be explained with the following example code:
-    // type t2
-    //  type(t1), pointer :: p1
-    // end type
-    // type t1
-    //  type(t2), pointer :: p2
-    // end type
-    // In the description below, type_self means a temporary type that is
-    // generated
-    // as a place holder while the members of that type are being processed.
-    //
-    // If we process t1 first then we will have the following structure after
-    // it has been processed.
-    // t1 -> t2 -> t1_self
-    // This is because when we started processing t2, we did not have the
-    // complete t1 but its place holder t1_self.
-    // Now if some entity requires t2, we will already have that in cache and
-    // will return it. But this t2 refers to t1_self and not to t1. In mlir
-    // handling, only those types are allowed to have _self reference which are
-    // wrapped by entity whose reference it is. So t1 -> t2 -> t1_self is ok
-    // because the t1_self reference can be resolved by the outer t1. But
-    // standalone t2 is not because there will be no way to resolve it. Until
-    // this is fixed in mlir, we avoid caching such types. Please see
-    // DebugTranslation::translateRecursive for details on how mlir handles
-    // recursive types.
-    // The code below checks for situation where it will be unsafe to cache
-    // a type to avoid this problem. We do that in 2 situations.
-    // 1. If a member is record type, then its type would have been processed
-    // before reaching here. If it is not in the cache, it means that it was
-    // found to be unsafe to cache. So any type containing it will also not
-    // be cached
-    // 2. The type of the member is found in the cache but it is a place holder.
-    // In this case, its recID should match the recID of the type we are
-    // processing. This helps us to cache the following type.
-    // type t
-    //  type(t), allocatable :: p
-    // end type
-    mlir::Type baseTy = getDerivedType(fieldTy);
-    if (auto recTy = mlir::dyn_cast<fir::RecordType>(baseTy)) {
-      auto iter = typeCache.find(recTy);
-      if (iter == typeCache.end())
-        canCacheThisType = false;
-      else {
-        if (auto tyAttr =
-                mlir::dyn_cast<mlir::LLVM::DICompositeTypeAttr>(iter->second)) {
-          if (tyAttr.getIsRecSelf() && tyAttr.getRecId() != recId)
-            canCacheThisType = false;
-        }
-      }
-    }
+    derivedTypeCache.postComponentVisitUpdate(nestedRecursions);
   }
 
   auto finalAttr = mlir::LLVM::DICompositeTypeAttr::get(
       context, recId, /*isRecSelf=*/false, llvm::dwarf::DW_TAG_structure_type,
-      mlir::StringAttr::get(context, result.second.name), fileAttr, line, scope,
+      mlir::StringAttr::get(context, sourceName.name), fileAttr, line, scope,
      /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8,
      /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
      /*allocated=*/nullptr, /*associated=*/nullptr);
 
-  // derivedTypeDepth == 1 means that it is a top level type which is safe to
-  // cache.
-  if (canCacheThisType || derivedTypeDepth == 1) {
-    typeCache[Ty] = finalAttr;
-  } else {
-    auto iter = typeCache.find(Ty);
-    if (iter != typeCache.end())
-      typeCache.erase(iter);
-  }
+  derivedTypeCache.finalize(Ty, finalAttr, std::move(nestedRecursions));
+
   return finalAttr;
 }
 
 mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
     mlir::TupleType Ty, mlir::LLVM::DIFileAttr fileAttr,
     mlir::LLVM::DIScopeAttr scope, fir::cg::XDeclareOp declOp) {
   // Check if this type has already been converted.
-  auto iter = typeCache.find(Ty);
-  if (iter != typeCache.end())
-    return iter->second;
+  if (mlir::LLVM::DITypeAttr attr = derivedTypeCache.lookup(Ty))
+    return attr;
+
+  DerivedTypeCache::ActiveLevels nestedRecursions =
+      derivedTypeCache.startTranslating(Ty);
 
   llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
   mlir::MLIRContext *context = module.getContext();
 
   std::uint64_t offset = 0;
   for (auto fieldTy : Ty.getTypes()) {
+    derivedTypeCache.preComponentVisitUpdate();
     auto [byteSize, byteAlign] = getFieldSizeAndAlign(fieldTy);
     mlir::LLVM::DITypeAttr elemTy =
         convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr);
@@ -445,6 +482,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
         /*extra data=*/nullptr);
     elements.push_back(tyAttr);
     offset += llvm::alignTo(byteSize, byteAlign);
+    derivedTypeCache.postComponentVisitUpdate(nestedRecursions);
   }
 
   auto typeAttr = mlir::LLVM::DICompositeTypeAttr::get(
@@ -453,7 +491,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
       /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8,
       /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr,
       /*allocated=*/nullptr, /*associated=*/nullptr);
-  typeCache[Ty] = typeAttr;
+  derivedTypeCache.finalize(Ty, typeAttr, std::move(nestedRecursions));
   return typeAttr;
 }
 
@@ -667,27 +705,7 @@ DebugTypeGenerator::convertType(mlir::Type Ty, mlir::LLVM::DIFileAttr fileAttr,
     return convertCharacterType(charTy, fileAttr, scope, declOp,
                                 /*hasDescriptor=*/false);
   } else if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(Ty)) {
-    // For nested derived types like shown below, the call sequence of the
-    // convertRecordType will look something like as follows:
-    // convertRecordType (t1)
-    //  convertRecordType (t2)
-    //   convertRecordType (t3)
-    // We need to recognize when we are processing the top level type like t1
-    // to make caching decision. The variable `derivedTypeDepth` is used for
-    // this purpose and maintains the current depth of derived type processing.
-    // type t1
-    //  type(t2), pointer :: p1
-    // end type
-    // type t2
-    //  type(t3), pointer :: p2
-    // end type
-    // type t2
-    //  integer a
-    // end type
-    derivedTypeDepth++;
-    auto result = convertRecordType(recTy, fileAttr, scope, declOp);
-    derivedTypeDepth--;
-    return result;
+    return convertRecordType(recTy, fileAttr, scope, declOp);
   } else if (auto tupleTy = mlir::dyn_cast_if_present<mlir::TupleType>(Ty)) {
     return convertTupleType(tupleTy, fileAttr, scope, declOp);
   } else if (auto refTy = mlir::dyn_cast_if_present<fir::ReferenceType>(Ty)) {
```
