20
20
21
21
#include " tc/core/check.h"
22
22
#include " tc/core/constants.h"
23
+ #include " tc/core/polyhedral/body.h"
23
24
#include " tc/core/polyhedral/schedule_isl_conversion.h"
24
25
#include " tc/core/polyhedral/schedule_transforms.h"
25
26
#include " tc/core/polyhedral/schedule_tree.h"
@@ -333,6 +334,90 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
333
334
return {finder.reads , finder.writes };
334
335
}
335
336
337
+ bool isReductionUpdate (const Provide* op) {
338
+ if (const Call* call = op->values [0 ].as <Call>()) {
339
+ return call->is_intrinsic (tc2halide::kReductionUpdate );
340
+ } else {
341
+ return false ;
342
+ }
343
+ }
344
+
345
+ /* Construct a multi-dimensional affine function mapping
346
+ * the given iteration domain
347
+ * to the outer loop iterators that do not appear in "skip".
348
+ * "id" is used as the identifier of the target space.
349
+ * For each of these outer loop iterators, an affine function
350
+ * is first constructed in terms of the parameter space
351
+ * active at the point where the iteration domain was created and
352
+ * then converted into an expression on that iteration domain
353
+ * by reinterpreting the parameters as input dimensions.
354
+ */
355
+ static isl::multi_aff mapToOther (
356
+ const IterationDomain& iterationDomain,
357
+ std::unordered_set<std::string> skip,
358
+ isl::id id) {
359
+ auto ctx = iterationDomain.tuple .get_ctx ();
360
+ auto list = isl::aff_list (ctx, 0 );
361
+ for (auto id : iterationDomain.tuple .get_id_list ()) {
362
+ if (skip.count (id.get_name ()) == 1 ) {
363
+ continue ;
364
+ }
365
+ auto aff = isl::aff::param_on_domain_space (iterationDomain.paramSpace , id);
366
+ aff = aff.unbind_params_insert_domain (iterationDomain.tuple );
367
+ list = list.add (aff);
368
+ }
369
+ auto domainSpace = iterationDomain.tuple .get_space ();
370
+ auto space = domainSpace.params ().named_set_from_params_id (id, list.size ());
371
+ space = domainSpace.product (space).unwrap ();
372
+ return isl::multi_aff (space, list);
373
+ }
374
+
375
+ /*
376
+ * If "op" performs a reduction, then return a mapping from
377
+ * the statement instances to the individual reductions.
378
+ * Otherwise, return an empty isl::union_map.
379
+ *
380
+ * "op" is considered to be a reduction if it has been marked
381
+ * as performing a reduction and if more than one statement instance
382
+ * is involved in the individual reductions.
383
+ *
384
+ * The space of the reduction has a name of the form R_<op->name>_<index>.
385
+ * Each reduction is indexed by the outer loop variables
386
+ * that are not marked as reduction variables.
387
+ * Since the loop variables that iterate over output tensor elements
388
+ * are never marked as reduction variables, this means in particular
389
+ * that all statement instances that belong to the same reduction
390
+ * write to the same tensor element.
391
+ */
392
+ isl::union_map extractReduction (
393
+ const IterationDomain& iterationDomain,
394
+ const Provide* op,
395
+ size_t index) {
396
+ class FindReductionVars : public IRVisitor {
397
+ void visit (const Variable* op) {
398
+ if (op->reduction_domain .defined ()) {
399
+ reductionVars.insert (op->name );
400
+ }
401
+ }
402
+
403
+ public:
404
+ // The variables that are known to be reduction variables.
405
+ std::unordered_set<std::string> reductionVars;
406
+ } finder;
407
+
408
+ if (!isReductionUpdate (op)) {
409
+ return isl::union_map::empty (iterationDomain.tuple .get_space ().params ());
410
+ }
411
+ op->accept (&finder);
412
+ if (finder.reductionVars .size () == 0 ) {
413
+ return isl::union_map::empty (iterationDomain.tuple .get_space ().params ());
414
+ }
415
+ auto ctx = iterationDomain.tuple .get_ctx ();
416
+ isl::id id (ctx, kReductionLabel + op->name + " _" + std::to_string (index));
417
+ auto reduction = mapToOther (iterationDomain, finder.reductionVars , id);
418
+ return isl::union_map (isl::map (reduction));
419
+ }
420
+
336
421
/*
337
422
* Take a parametric expression "f" and convert it into an expression
338
423
* on the iteration domains in "domain" by reinterpreting the parameters
@@ -360,7 +445,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
360
445
* from outermost to innermost.
361
446
* Return the schedule corresponding to the subtree at "s".
362
447
*
363
- * "reads" and "writes" collect the accesses found along the way.
448
+ * "body" collects the accesses and reductions found along the way.
364
449
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
365
450
* (for the writes) to the corresponding tag in the access relations.
366
451
* "statements" collects the mapping from instance set tuple identifiers
@@ -372,8 +457,7 @@ isl::schedule makeScheduleTreeHelper(
372
457
const Stmt& s,
373
458
isl::set set,
374
459
isl::id_list outer,
375
- isl::union_map* reads,
376
- isl::union_map* writes,
460
+ Body* body,
377
461
AccessMap* accesses,
378
462
StatementMap* statements,
379
463
IterationDomainMap* domains) {
@@ -406,19 +490,19 @@ isl::schedule makeScheduleTreeHelper(
406
490
407
491
// Recursively descend.
408
492
auto outerNext = outer.add (isl::id (set.get_ctx (), op->name ));
409
- auto body = makeScheduleTreeHelper (
410
- op->body , set, outerNext, reads, writes , accesses, statements, domains);
493
+ auto bodySchedule = makeScheduleTreeHelper (
494
+ op->body , set, outerNext, body , accesses, statements, domains);
411
495
412
496
// Create an affine function that defines an ordering for all
413
497
// the statements in the body of this loop over the values of
414
498
// this loop. Start from a parametric expression equal
415
499
// to the current loop iterator and then convert it to
416
500
// a function on the statements in the domain of the body schedule.
417
501
auto aff = isl::aff::param_on_domain_space (space, id);
418
- auto domain = body .get_domain ();
502
+ auto domain = bodySchedule .get_domain ();
419
503
auto mupa = isl::multi_union_pw_aff (onDomains (aff, domain, *domains));
420
504
421
- schedule = body .insert_partial_schedule (mupa);
505
+ schedule = bodySchedule .insert_partial_schedule (mupa);
422
506
} else if (auto op = s.as <Halide::Internal::Block>()) {
423
507
std::vector<Stmt> stmts;
424
508
stmts.push_back (op->first );
@@ -429,7 +513,7 @@ isl::schedule makeScheduleTreeHelper(
429
513
std::vector<isl::schedule> schedules;
430
514
for (Stmt stmt : stmts) {
431
515
schedules.push_back (makeScheduleTreeHelper (
432
- stmt, set, outer, reads, writes , accesses, statements, domains));
516
+ stmt, set, outer, body , accesses, statements, domains));
433
517
}
434
518
schedule = schedules[0 ].sequence (schedules[1 ]);
435
519
@@ -452,9 +536,13 @@ isl::schedule makeScheduleTreeHelper(
452
536
isl::union_map newReads, newWrites;
453
537
std::tie (newReads, newWrites) =
454
538
extractAccesses (iterationDomain, op, accesses);
539
+ // A tensor may be involved in multiple reductions.
540
+ // Use the statement index to differentiate between them.
541
+ auto newReduction = extractReduction (iterationDomain, op, stmtIndex);
455
542
456
- *reads = reads->unite (newReads);
457
- *writes = writes->unite (newWrites);
543
+ body->reads = body->reads .unite (newReads);
544
+ body->writes = body->writes .unite (newWrites);
545
+ body->reductions = body->reductions .unite (newReduction);
458
546
459
547
} else {
460
548
LOG (FATAL) << " Unhandled Halide stmt: " << s;
@@ -465,87 +553,24 @@ isl::schedule makeScheduleTreeHelper(
465
553
// Build the schedule tree and the associated information for
// the Halide statement "s", using "paramSpace" as the parameter space.
// makeScheduleTreeHelper is called on "s" with the universe set of
// "paramSpace" and an empty list of outer identifiers; it is handed
// pointers through which the body (accesses and reductions) and
// the accesses, statements and domains fields of the result are collected.
// The resulting isl schedule is converted by fromIslSchedule and
// stored in the "tree" field of the result.
// NOTE(review): this span is a diff rendering; lines starting with "-"
// are the replaced code and lines starting with "+" are the replacement.
ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
466
554
ScheduleTreeAndAccesses result;
467
555
468
- result. writes = result. reads = isl::union_map::empty (paramSpace);
556
+ Body body (paramSpace);
469
557
470
558
// Walk the IR building a schedule tree
471
559
isl::id_list outer (paramSpace.get_ctx (), 0 );
472
560
auto schedule = makeScheduleTreeHelper (
473
561
s,
474
562
isl::set::universe (paramSpace),
475
563
outer,
476
- &result.reads ,
477
- &result.writes ,
564
+ &body,
478
565
&result.accesses ,
479
566
&result.statements ,
480
567
&result.domains );
481
568
569
+ result.body = body;
482
570
result.tree = fromIslSchedule (schedule);
483
571
484
572
return result;
485
573
}
486
574
487
- std::vector<Reduction> findReductions (const Stmt& s) {
488
- class FindReductions : public IRVisitor {
489
- using IRVisitor::visit;
490
-
491
- bool isReductionUpdate (const Provide* op) {
492
- if (const Call* call = op->values [0 ].as <Call>()) {
493
- return call->is_intrinsic (tc2halide::kReductionUpdate );
494
- } else {
495
- return false ;
496
- }
497
- }
498
-
499
- // Keep track of any reduction variable name for use in visit(Provide*)
500
- void visit (const Variable* op) {
501
- if (op->reduction_domain .defined ()) {
502
- reductionVars.insert (op->name );
503
- }
504
- }
505
-
506
- // Keep track of the names of the outer For nodes.
507
- void visit (const For* op) {
508
- vars.push_back (op->name );
509
- IRVisitor::visit (op);
510
- vars.pop_back ();
511
- }
512
-
513
- // Check if the node is an update node with at least one reduction
514
- // dimension, keeping track of the information about the reduction.
515
- // In particular, collect the positions of the reduction
516
- // dimensions in the update statement domain.
517
- // Visit the children first to ensure that all relevant
518
- // reduction variables have been found first.
519
- void visit (const Provide* op) {
520
- IRVisitor::visit (op);
521
- if (isReductionUpdate (op)) {
522
- std::vector<size_t > dims;
523
- auto n = vars.size ();
524
- for (size_t i = 0 ; i < n; ++i) {
525
- if (reductionVars.count (vars[i]) != 0 ) {
526
- dims.emplace_back (i);
527
- }
528
- }
529
- if (dims.size () > 0 ) {
530
- Reduction p;
531
- p.update = op;
532
- p.dims = dims;
533
- reductions.emplace_back (p);
534
- }
535
- }
536
- }
537
-
538
- public:
539
- // The variables that are known to be reduction variables.
540
- std::unordered_set<std::string> reductionVars;
541
- // The names of the outer For nodes, outermost to innermost.
542
- std::vector<std::string> vars;
543
- std::vector<Reduction> reductions;
544
- } finder;
545
- s.accept (&finder);
546
-
547
- return finder.reductions ;
548
- }
549
-
550
575
} // namespace halide2isl
551
576
} // namespace tc
0 commit comments