Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2fb2c39

Browse files
committed
extract "extractDomainToIds" from extractMappingToThreads
The core functionality for extracting mapped affine expressions can be reused for individual mapping ids or other types of mapping ids, e.g., block ids. Therefore, it should live with other schedule operations.
1 parent 49a2965 commit 2fb2c39

File tree

3 files changed

+49
-21
lines changed

3 files changed

+49
-21
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -476,28 +476,12 @@ constexpr auto kWarp = "warp";
476476
isl::multi_union_pw_aff extractDomainToThread(
477477
const detail::ScheduleTree* tree,
478478
size_t nThread) {
479-
using namespace polyhedral::detail;
480-
481-
auto space = isl::space(tree->ctx_, 0);
482-
auto empty = isl::union_set::empty(space);
483-
auto id = isl::id(tree->ctx_, kBlock);
484-
space = space.named_set_from_params_id(id, nThread);
485-
auto zero = isl::multi_val::zero(space);
486-
auto domainToThread = isl::multi_union_pw_aff(empty, zero);
487-
488-
for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) {
489-
auto mappingNode = mapping->elemAs<ScheduleTreeElemMappingFilter>();
490-
auto list = isl::union_pw_aff_list(tree->ctx_, nThread);
491-
for (size_t i = 0; i < nThread; ++i) {
492-
auto threadId = mapping::ThreadId::makeId(i);
493-
auto threadMap = mappingNode->mapping.at(threadId);
494-
list = list.add(threadMap);
495-
}
496-
auto nodeToThread = isl::multi_union_pw_aff(space, list);
497-
domainToThread = domainToThread.union_add(nodeToThread);
479+
std::vector<mapping::MappingId> ids;
480+
for (size_t i = 0; i < nThread; ++i) {
481+
ids.emplace_back(mapping::ThreadId::makeId(i));
498482
}
499-
500-
return domainToThread;
483+
auto tupleId = isl::id(tree->ctx_, kBlock);
484+
return extractDomainToIds(tree, ids, tupleId);
501485
}
502486

503487
/*

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "tc/core/check.h"
3232
#include "tc/core/constants.h"
3333
#include "tc/core/polyhedral/functional.h"
34+
#include "tc/core/polyhedral/mapping_types.h"
3435
#include "tc/core/polyhedral/schedule_tree_elem.h"
3536
#include "tc/core/polyhedral/schedule_tree_matcher.h"
3637
#include "tc/core/scope_guard.h"
@@ -798,5 +799,38 @@ void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) {
798799
parent->insertChild(childPos, std::move(seq));
799800
}
800801

802+
/*
803+
* Extract a mapping from the domain elements active at "tree"
804+
* to identifiers "ids", where all branches in "tree"
805+
* are assumed to have been mapped to these identifiers.
806+
* The result lives in a space of the form "tupleId"["ids"...].
807+
*/
808+
isl::multi_union_pw_aff extractDomainToIds(
809+
const detail::ScheduleTree* tree,
810+
const std::vector<mapping::MappingId>& ids,
811+
isl::id tupleId) {
812+
using namespace polyhedral::detail;
813+
814+
auto space = isl::space(tree->ctx_, 0);
815+
auto empty = isl::union_set::empty(space);
816+
space = space.named_set_from_params_id(tupleId, ids.size());
817+
auto zero = isl::multi_val::zero(space);
818+
auto domainToIds = isl::multi_union_pw_aff(empty, zero);
819+
820+
for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) {
821+
auto mappingNode = mapping->elemAs<ScheduleTreeElemMappingFilter>();
822+
auto list = isl::union_pw_aff_list(tree->ctx_, ids.size());
823+
for (auto id : ids) {
824+
CHECK_GT(mappingNode->mapping.count(id), 0) << "no mapping to id " << id;
825+
auto idMap = mappingNode->mapping.at(id);
826+
list = list.add(idMap);
827+
}
828+
auto nodeToIds = isl::multi_union_pw_aff(space, list);
829+
domainToIds = domainToIds.union_add(nodeToIds);
830+
}
831+
832+
return domainToIds;
833+
}
834+
801835
} // namespace polyhedral
802836
} // namespace tc

tc/core/polyhedral/schedule_transforms.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <vector>
2323

2424
#include "tc/core/polyhedral/functional.h"
25+
#include "tc/core/polyhedral/mapping_types.h"
2526
#include "tc/core/polyhedral/options.h"
2627
#include "tc/core/polyhedral/schedule_tree.h"
2728
#include "tc/external/isl.h"
@@ -325,6 +326,15 @@ isl::union_set activeDomainPointsBelow(
325326
const detail::ScheduleTree* root,
326327
const detail::ScheduleTree* node);
327328

329+
// Extract a mapping from the domain elements active at "tree"
330+
// to identifiers "ids", where all branches in "tree"
331+
// are assumed to have been mapped to these identifiers.
332+
// The result lives in a space of the form "tupleId"["ids"...].
333+
isl::multi_union_pw_aff extractDomainToIds(
334+
const detail::ScheduleTree* tree,
335+
const std::vector<mapping::MappingId>& ids,
336+
isl::id tupleId);
337+
328338
} // namespace polyhedral
329339
} // namespace tc
330340

0 commit comments

Comments
 (0)