21
21
#include " tc/core/check.h"
22
22
#include " tc/core/constants.h"
23
23
#include " tc/core/polyhedral/body.h"
24
+ #include " tc/core/polyhedral/domain_types.h"
24
25
#include " tc/core/polyhedral/schedule_isl_conversion.h"
25
26
#include " tc/core/polyhedral/schedule_transforms.h"
26
27
#include " tc/core/polyhedral/schedule_tree.h"
@@ -80,14 +81,14 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
80
81
return builder.table ;
81
82
}
82
83
83
- isl::aff makeIslAffFromInt (isl::space space, int64_t val) {
84
+ isl::AffOn<> makeIslAffFromInt (isl::Space<> space, int64_t val) {
84
85
isl::val v = isl::val (space.get_ctx (), val);
85
- return isl::aff (isl::local_space (space), v);
86
+ return isl::AffOn<>( isl:: aff (isl::local_space (space), v) );
86
87
}
87
88
88
- std::vector<isl::aff > makeIslAffBoundsFromExpr (
89
- isl::space space,
90
- const Expr& e,
89
+ std::vector<isl::AffOn<> > makeIslAffBoundsFromExpr (
90
+ isl::Space<> space,
91
+ const Halide:: Expr& e,
91
92
bool allowMin,
92
93
bool allowMax);
93
94
@@ -101,9 +102,9 @@ namespace {
101
102
* x > max(a,max(b,c)) <=> x > a AND x > b AND x > c
102
103
*/
103
104
template <typename T>
104
- inline std::vector<isl::aff >
105
- concatAffs (isl::space space, T op, bool allowMin, bool allowMax) {
106
- std::vector<isl::aff > result;
105
+ inline std::vector<isl::AffOn<> >
106
+ concatAffs (isl::Space<> space, T op, bool allowMin, bool allowMax) {
107
+ std::vector<isl::AffOn<> > result;
107
108
108
109
for (const auto & aff :
109
110
makeIslAffBoundsFromExpr (space, op->a , allowMin, allowMax)) {
@@ -129,10 +130,10 @@ concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
129
130
* x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
130
131
*/
131
132
template <typename T>
132
- inline std::vector<isl::aff > combineSingleAffs (
133
- isl::space space,
133
+ inline std::vector<isl::AffOn<> > combineSingleAffs (
134
+ isl::Space<> space,
134
135
T op,
135
- isl::aff (isl::aff ::*combine)(isl::aff ) const ) {
136
+ isl::AffOn<> (isl::AffOn<> ::*combine)(const isl::AffOn<>& ) const ) {
136
137
auto left = makeIslAffBoundsFromExpr (space, op->a , false , false );
137
138
auto right = makeIslAffBoundsFromExpr (space, op->b , false , false );
138
139
TC_CHECK_LE (left.size (), 1u );
@@ -162,9 +163,9 @@ inline std::vector<isl::aff> combineSingleAffs(
162
163
* If a Halide expression cannot be converted into a list of affine expressions,
163
164
* return an empty list.
164
165
*/
165
- std::vector<isl::aff > makeIslAffBoundsFromExpr (
166
- isl::space space,
167
- const Expr& e,
166
+ std::vector<isl::AffOn<> > makeIslAffBoundsFromExpr (
167
+ isl::Space<> space,
168
+ const Halide:: Expr& e,
168
169
bool allowMin,
169
170
bool allowMax) {
170
171
TC_CHECK (!(allowMin && allowMax));
@@ -178,7 +179,7 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
178
179
if (const Variable* op = e.as <Variable>()) {
179
180
isl::id id (space.get_ctx (), op->name );
180
181
if (space.has_param (id)) {
181
- return {isl::aff ::param_on_domain_space (space, id)};
182
+ return {isl::AffOn<> ::param_on_domain_space (space, id)};
182
183
}
183
184
LOG (FATAL) << " Variable not found in isl::space: " << space << " : " << op
184
185
<< " : " << op->name << ' \n ' ;
@@ -188,13 +189,13 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
188
189
} else if (maxOp != nullptr && allowMax) {
189
190
return concatAffs (space, maxOp, allowMin, allowMax);
190
191
} else if (const Add* op = e.as <Add>()) {
191
- return combineSingleAffs (space, op, &isl::aff ::add);
192
+ return combineSingleAffs (space, op, &isl::AffOn<> ::add);
192
193
} else if (const Sub* op = e.as <Sub>()) {
193
- return combineSingleAffs (space, op, &isl::aff ::sub);
194
+ return combineSingleAffs (space, op, &isl::AffOn<> ::sub);
194
195
} else if (const Mul* op = e.as <Mul>()) {
195
- return combineSingleAffs (space, op, &isl::aff ::mul);
196
+ return combineSingleAffs (space, op, &isl::AffOn<> ::mul);
196
197
} else if (const Div* op = e.as <Div>()) {
197
- return combineSingleAffs (space, op, &isl::aff ::div);
198
+ return combineSingleAffs (space, op, &isl::AffOn<> ::div);
198
199
} else if (const Mod* op = e.as <Mod>()) {
199
200
std::vector<isl::aff> result;
200
201
// We cannot span multiple constraints if a modulo operation is involved.
@@ -211,45 +212,50 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
211
212
return {};
212
213
}
213
214
214
- isl::aff makeIslAffFromExpr (isl::space space, const Expr& e) {
215
+ isl::AffOn<> makeIslAffFromExpr (isl::Space<> space, const Halide:: Expr& e) {
215
216
auto list = makeIslAffBoundsFromExpr (space, e, false , false );
216
217
TC_CHECK_LE (list.size (), 1u )
217
218
<< " Halide expr " << e << " unrolled into more than 1 isl aff"
218
219
<< " but min/max operations were disabled" ;
219
220
220
221
// Non-affine
221
222
if (list.size () == 0 ) {
222
- return isl::aff ();
223
+ return isl::AffOn<> ();
223
224
}
224
225
return list[0 ];
225
226
}
226
227
227
- isl::space makeParamSpace (isl::ctx ctx, const ParameterVector& params) {
228
+ isl::Space<> makeParamSpace (isl::ctx ctx, const ParameterVector& params) {
228
229
auto space = isl::space (ctx, 0 );
229
230
// set parameter names
230
231
for (auto p : params) {
231
232
space = space.add_param (isl::id (ctx, p.name ()));
232
233
}
233
- return space;
234
+ return isl::Space<>( space) ;
234
235
}
235
236
236
- isl::set makeParamContext (isl::ctx ctx, const ParameterVector& params) {
237
+ isl::Set<> makeParamContext (isl::ctx ctx, const ParameterVector& params) {
237
238
auto space = makeParamSpace (ctx, params);
238
- auto context = isl::set ::universe (space);
239
+ auto context = isl::Set<> ::universe (space);
239
240
for (auto p : params) {
240
- isl::aff a (isl::aff ::param_on_domain_space (space, isl::id (ctx, p.name ())));
241
- context = context & (a >= 0 );
241
+ auto a (isl::AffOn<> ::param_on_domain_space (space, isl::id (ctx, p.name ())));
242
+ context = context & isl::PwAffOn<>(a). nonneg_set ( );
242
243
}
243
244
return context;
244
245
}
245
246
246
247
namespace {
247
248
248
- isl::map extractAccess (
249
+ template <typename Domain, typename Range>
250
+ static isl::MultiAff<isl::Pair<Domain, Range>, Domain> domainMap (isl::Space<Domain, Range> space) {
251
+ return isl::MultiAff<isl::Pair<Domain, Range>, Domain>::domain_map (space);
252
+ }
253
+
254
+ isl::Map<isl::Pair<Statement, Tag>, Tensor> extractAccess (
249
255
const IterationDomain& domain,
250
256
const IRNode* op,
251
257
const std::string& tensor,
252
- const std::vector<Expr>& args,
258
+ const std::vector<Halide:: Expr>& args,
253
259
AccessMap* accesses) {
254
260
// Make an isl::map representing this access. It maps from the iteration space
255
261
// to the tensor's storage space, using the coordinates accessed.
@@ -258,16 +264,16 @@ isl::map extractAccess(
258
264
// to the outer loop iterators) and then convert this set
259
265
// into a map in terms of the iteration domain.
260
266
261
- auto paramSpace = isl::Space<>( domain.paramSpace ) ;
267
+ auto paramSpace = domain.paramSpace ;
262
268
isl::id tensorID (paramSpace.get_ctx (), tensor);
263
269
auto tensorTuple = constructTensorTuple (paramSpace, tensorID, args.size ());
264
270
auto tensorSpace = tensorTuple.get_space ();
265
271
266
272
// Start with a totally unconstrained set - every point in
267
273
// the allocation could be accessed.
268
- isl::set access = isl::set ::universe (tensorSpace);
274
+ auto access = isl::Set<Tensor> ::universe (tensorSpace);
269
275
270
- auto identity = isl::multi_aff ::identity (tensorSpace.map_from_set ());
276
+ auto identity = isl::MultiAff<Tensor, Tensor> ::identity (tensorSpace.map_from_set ());
271
277
for (size_t i = 0 ; i < args.size (); i++) {
272
278
// Then add one equality constraint per dimension to encode the
273
279
// point in the allocation actually read/written for each point in
@@ -279,8 +285,8 @@ isl::map extractAccess(
279
285
// ... equals the coordinate accessed as a function of the parameters.
280
286
auto domainPoint = halide2isl::makeIslAffFromExpr (paramSpace, args[i]);
281
287
if (!domainPoint.is_null ()) {
282
- domainPoint = domainPoint.unbind_params_insert_domain (tensorTuple);
283
- access = access.intersect (domainPoint .eq_set (rangePoint));
288
+ auto domainPoint2 = domainPoint.unbind_params_insert_domain (tensorTuple);
289
+ access = access.intersect (domainPoint2 .eq_set (rangePoint));
284
290
}
285
291
}
286
292
@@ -292,15 +298,13 @@ isl::map extractAccess(
292
298
std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
293
299
isl::id tagID (domain.paramSpace .get_ctx (), tag);
294
300
accesses->emplace (op, tagID);
295
- isl::space domainSpace = map.get_space ().domain ();
296
- isl::space tagSpace = domainSpace.params ().add_named_tuple_id_ui (tagID, 0 );
297
- domainSpace = domainSpace.product (tagSpace).unwrap ();
298
- map = map.preimage_domain (isl::multi_aff::domain_map (domainSpace));
299
-
300
- return map;
301
+ auto domainSpace = map.get_space ().domain ();
302
+ auto tagSpace = domainSpace.params ().add_named_tuple_id_ui <Tag>(tagID, 0 );
303
+ auto taggedSpace = domainSpace.product (tagSpace).unwrap <Statement, Tag>();
304
+ return map.preimage_domain (domainMap (taggedSpace));
301
305
}
302
306
303
- std::pair<isl::union_map, isl::union_map > extractAccesses (
307
+ std::pair<isl::UnionMap<isl::Pair<Statement, Tag>, Tensor>, isl::UnionMap<isl::Pair<Statement, Tag>, Tensor> > extractAccesses (
304
308
const IterationDomain& domain,
305
309
const Stmt& s,
306
310
AccessMap* accesses) {
@@ -325,7 +329,7 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
325
329
AccessMap* accesses;
326
330
327
331
public:
328
- isl::union_map reads, writes;
332
+ isl::UnionMap<isl::Pair<Statement, Tag>, Tensor> reads, writes;
329
333
330
334
FindAccesses (const IterationDomain& domain, AccessMap* accesses)
331
335
: domain(domain),
@@ -355,24 +359,24 @@ bool isReductionUpdate(const Provide* op) {
355
359
* then converted into an expression on that iteration domain
356
360
* by reinterpreting the parameters as input dimensions.
357
361
*/
358
- static isl::multi_aff mapToOther (
362
+ template <typename Other>
363
+ static isl::MultiAff<Statement, Other> mapToOther (
359
364
const IterationDomain& iterationDomain,
360
365
std::unordered_set<std::string> skip,
361
366
isl::id id) {
362
367
auto ctx = iterationDomain.tuple .get_ctx ();
363
- auto list = isl::aff_list (ctx, 0 );
368
+ auto list = isl::AffListOn<Statement>( isl:: aff_list (ctx, 0 ) );
364
369
for (auto id : iterationDomain.tuple .get_id_list ()) {
365
370
if (skip.count (id.get_name ()) == 1 ) {
366
371
continue ;
367
372
}
368
- auto aff = isl::aff::param_on_domain_space (iterationDomain.paramSpace , id);
369
- aff = aff.unbind_params_insert_domain (iterationDomain.tuple );
370
- list = list.add (aff);
373
+ auto aff = isl::AffOn<>::param_on_domain_space (iterationDomain.paramSpace , id);
374
+ list = list.add (aff.unbind_params_insert_domain (iterationDomain.tuple ));
371
375
}
372
376
auto domainSpace = iterationDomain.tuple .get_space ();
373
- auto space = domainSpace.params ().add_named_tuple_id_ui (id, list.size ());
374
- space = domainSpace.product (space).unwrap ();
375
- return isl::multi_aff (space , list);
377
+ auto space = domainSpace.params ().add_named_tuple_id_ui <Other> (id, list.size ());
378
+ auto productSpace = domainSpace.product (space).template unwrap <Statement, Other> ();
379
+ return isl::MultiAff<Statement, Other>(productSpace , list);
376
380
}
377
381
378
382
/*
@@ -392,7 +396,7 @@ static isl::multi_aff mapToOther(
392
396
* that all statement instances that belong to the same reduction
393
397
* write to the same tensor element.
394
398
*/
395
- isl::union_map extractReduction (
399
+ isl::UnionMap<Statement, Reduction> extractReduction (
396
400
const IterationDomain& iterationDomain,
397
401
const Provide* op,
398
402
size_t index) {
@@ -409,16 +413,19 @@ isl::union_map extractReduction(
409
413
} finder;
410
414
411
415
if (!isReductionUpdate (op)) {
412
- return isl::union_map::empty (iterationDomain.tuple .get_space ().params ());
416
+ auto space = iterationDomain.tuple .get_space ().params ();
417
+ return isl::UnionMap<Statement, Reduction>::empty (space);
413
418
}
414
419
op->accept (&finder);
415
420
if (finder.reductionVars .size () == 0 ) {
416
- return isl::union_map::empty (iterationDomain.tuple .get_space ().params ());
421
+ auto space = iterationDomain.tuple .get_space ().params ();
422
+ return isl::UnionMap<Statement, Reduction>(isl::union_map::empty (space));
417
423
}
418
424
auto ctx = iterationDomain.tuple .get_ctx ();
419
425
isl::id id (ctx, kReductionLabel + op->name + " _" + std::to_string (index));
420
- auto reduction = mapToOther (iterationDomain, finder.reductionVars , id);
421
- return isl::union_map (isl::map (reduction));
426
+ auto reduction = mapToOther<Reduction>(iterationDomain, finder.reductionVars , id);
427
+ return isl::UnionMap<Statement, Reduction>(
428
+ isl::Map<Statement, Reduction>(reduction));
422
429
}
423
430
424
431
/*
@@ -458,7 +465,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
458
465
*/
459
466
isl::schedule makeScheduleTreeHelper (
460
467
const Stmt& s,
461
- isl::set set,
468
+ isl::Set<> set,
462
469
isl::id_list outer,
463
470
Body* body,
464
471
AccessMap* accesses,
@@ -472,7 +479,7 @@ isl::schedule makeScheduleTreeHelper(
472
479
473
480
// Construct a variable (affine function) that references
474
481
// the new parameter.
475
- auto loopVar = isl::aff ::param_on_domain_space (space, id);
482
+ auto loopVar = isl::AffOn<> ::param_on_domain_space (space, id);
476
483
477
484
// Then we add our new loop bound constraints.
478
485
auto lbs =
@@ -483,7 +490,7 @@ isl::schedule makeScheduleTreeHelper(
483
490
set = set.intersect (loopVar.ge_set (lb));
484
491
}
485
492
486
- Expr max = simplify (op->min + op->extent - 1 );
493
+ Halide:: Expr max = simplify (op->min + op->extent - 1 );
487
494
auto ubs = halide2isl::makeIslAffBoundsFromExpr (space, max, true , false );
488
495
TC_CHECK_GT (ubs.size (), 0u )
489
496
<< " could not obtain polyhedral upper bounds from " << max;
@@ -527,16 +534,16 @@ isl::schedule makeScheduleTreeHelper(
527
534
size_t stmtIndex = statements->size ();
528
535
isl::id id (set.get_ctx (), kStatementLabel + std::to_string (stmtIndex));
529
536
statements->emplace (id, op);
530
- auto tupleSpace = isl::space (set.get_ctx (), 0 );
531
- tupleSpace = tupleSpace .add_named_tuple_id_ui (id, outer.size ());
537
+ auto space = isl::Space<>( isl:: space (set.get_ctx (), 0 ) );
538
+ auto tupleSpace = space .add_named_tuple_id_ui <Statement> (id, outer.size ());
532
539
IterationDomain iterationDomain;
533
540
iterationDomain.paramSpace = set.get_space ();
534
- iterationDomain.tuple = isl::multi_id (tupleSpace, outer);
541
+ iterationDomain.tuple = isl::MultiId<Statement> (tupleSpace, outer);
535
542
domains->emplace (id, iterationDomain);
536
543
auto domain = set.unbind_params (iterationDomain.tuple );
537
544
schedule = isl::schedule::from_domain (domain);
538
545
539
- isl::union_map newReads, newWrites;
546
+ isl::UnionMap<isl::Pair<Statement, Tag>, Tensor> newReads, newWrites;
540
547
std::tie (newReads, newWrites) =
541
548
extractAccesses (iterationDomain, op, accesses);
542
549
// A tensor may be involved in multiple reductions.
@@ -553,7 +560,7 @@ isl::schedule makeScheduleTreeHelper(
553
560
return schedule;
554
561
};
555
562
556
- ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
563
+ ScheduleTreeAndAccesses makeScheduleTree (isl::Space<> paramSpace, const Stmt& s) {
557
564
ScheduleTreeAndAccesses result;
558
565
559
566
Body body (paramSpace);
@@ -562,7 +569,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
562
569
isl::id_list outer (paramSpace.get_ctx (), 0 );
563
570
auto schedule = makeScheduleTreeHelper (
564
571
s,
565
- isl::set ::universe (paramSpace),
572
+ isl::Set<> ::universe (paramSpace),
566
573
outer,
567
574
&body,
568
575
&result.accesses ,
0 commit comments