Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 53e6185

Browse files
authored
Merge pull request #405 from facebookresearch/pr/space
simplify space manipulations
2 parents 621901d + 25a014a commit 53e6185

File tree

6 files changed

+20
-21
lines changed

6 files changed

+20
-21
lines changed

tc/core/halide2isl.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,17 +258,15 @@ isl::map extractAccess(
258258

259259
isl::space domainSpace = domain.get_space();
260260
isl::space paramSpace = domainSpace.params();
261-
262-
isl::space rangeSpace = paramSpace.add_dims(isl::dim_type::set, args.size());
263-
264-
rangeSpace = rangeSpace.set_tuple_name(isl::dim_type::set, tensor);
261+
isl::id tensorID(paramSpace.get_ctx(), tensor);
262+
auto rangeSpace = paramSpace.named_set_from_params_id(tensorID, args.size());
265263

266264
// Add a tag to the domain space so that we can maintain a mapping
267265
// between each access in the IR and the reads/writes maps.
268266
std::string tag = "__tc_ref_" + std::to_string(accesses->size());
269267
isl::id tagID(domain.get_ctx(), tag);
270268
accesses->emplace(op, tagID);
271-
isl::space tagSpace = paramSpace.set_tuple_name(isl::dim_type::set, tag);
269+
isl::space tagSpace = paramSpace.named_set_from_params_id(tagID, 0);
272270
domainSpace = domainSpace.product(tagSpace);
273271

274272
// Start with a totally unconstrained relation - every point in

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,8 @@ isl::multi_aff makeMultiAffAccess(
481481
CHECK_NE(subscripts.size(), 0u) << "cannot build subscript aff for a scalar";
482482

483483
auto domainSpace = findDomainSpaceById(context);
484-
auto tensorSpace = domainSpace.params().set_from_params().add_dims(
485-
isl::dim_type::set, subscripts.size());
486-
tensorSpace = tensorSpace.set_tuple_id(isl::dim_type::set, tensorId);
484+
auto tensorSpace = domainSpace.params().named_set_from_params_id(
485+
tensorId, subscripts.size());
487486
auto space = domainSpace.map_from_domain_and_range(tensorSpace);
488487

489488
auto ma = isl::multi_aff::zero(space);

tc/core/polyhedral/memory_promotion.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ ScopedFootprint outputRanges(isl::map access) {
9898
}
9999
return footprint;
100100
}
101+
102+
// Given a set space, construct a map space with the input as domain and
103+
// a range of the given size.
104+
isl::space add_range(isl::space space, unsigned dim) {
105+
auto range = space.params().unnamed_set_from_params(dim);
106+
return space.map_from_domain_and_range(range);
107+
}
108+
101109
} // namespace
102110

103111
// Access has the shape :: [D -> ref] -> O
@@ -128,8 +136,7 @@ std::unique_ptr<TensorReferenceGroup> TensorReferenceGroup::makeSingleton(
128136
}
129137

130138
isl::set ScopedFootprint::footprint(isl::set domain) const {
131-
auto space = domain.get_space().from_domain();
132-
space = space.add_dims(isl::dim_type::out, size());
139+
auto space = add_range(domain.get_space(), size());
133140
auto accessed = isl::map::universe(space).intersect_domain(domain);
134141
auto lspace = isl::local_space(accessed.get_space().range());
135142

@@ -147,8 +154,7 @@ isl::multi_aff ScopedFootprint::lowerBounds() const {
147154
if (size() == 0) {
148155
throw promotion::PromotionNYI("promotion for scalars");
149156
}
150-
auto space = at(0).lowerBound.get_space();
151-
space = space.add_dims(isl::dim_type::out, size() - 1);
157+
auto space = add_range(at(0).lowerBound.get_space().domain(), size());
152158
auto ma = isl::multi_aff::zero(space);
153159

154160
int i = 0;
@@ -402,8 +408,7 @@ isl::set tensorElementsSet(const Scop& scop, isl::id tensorId) {
402408
auto halideParameter = scop.findArgument(tensorId).parameter();
403409
auto space = scop.domain().get_space().params();
404410
auto nDim = halideParameter.dimensions();
405-
space = space.add_dims(isl::dim_type::set, nDim)
406-
.set_tuple_id(isl::dim_type::set, tensorId);
411+
space = space.named_set_from_params_id(tensorId, nDim);
407412

408413
auto tensorElements = isl::set::universe(space);
409414
for (int i = 0; i < nDim; ++i) {
@@ -443,7 +448,7 @@ isl::multi_aff dropDummyTensorDimensions(
443448
}
444449
}
445450

446-
space = space.from_domain().add_dims(isl::dim_type::out, list.n());
451+
space = add_range(space, list.n());
447452
return isl::multi_aff(space, list);
448453
}
449454
} // namespace

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,7 @@ detail::ScheduleTree* insertEmptyExtensionAbove(
559559
isl::map labelExtension(ScheduleTree* root, ScheduleTree* tree, isl::id id) {
560560
auto prefix = prefixScheduleMupa(root, tree);
561561
auto scheduleSpace = prefix.get_space();
562-
auto space = scheduleSpace.params().set_from_params().set_tuple_id(
563-
isl::dim_type::set, id);
562+
auto space = scheduleSpace.params().named_set_from_params_id(id, 0);
564563
auto extensionSpace = scheduleSpace.map_from_domain_and_range(space);
565564
return isl::map::universe(extensionSpace);
566565
}

tc/core/polyhedral/scop.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,15 +525,13 @@ isl::aff Scop::makeIslAffFromStmtExpr(
525525
const Halide::Expr& e) const {
526526
auto ctx = stmtId.get_ctx();
527527
auto iterators = halide.iterators.at(stmtId);
528-
auto space = paramSpace.set_from_params();
529-
space = space.add_dims(isl::dim_type::set, iterators.size());
528+
auto space = paramSpace.named_set_from_params_id(stmtId, iterators.size());
530529
// Set the names of the set dimensions of "space" for use
531530
// by halide2isl::makeIslAffFromExpr.
532531
for (size_t i = 0; i < iterators.size(); ++i) {
533532
isl::id id(ctx, iterators[i]);
534533
space = space.set_dim_id(isl::dim_type::set, i, id);
535534
}
536-
space = space.set_tuple_id(isl::dim_type::set, stmtId);
537535
return halide2isl::makeIslAffFromExpr(space, e);
538536
}
539537

third-party/islpp

Submodule islpp updated from e5b7b42 to 72af18e

0 commit comments

Comments
 (0)