21
21
#include < utility>
22
22
23
23
#include " tc/core/flags.h"
24
- #include " tc/core/halide2isl.h"
25
24
#include " tc/core/islpp_wrap.h"
26
25
#include " tc/core/libraries.h"
27
26
#include " tc/core/polyhedral/codegen.h"
@@ -368,11 +367,7 @@ void emitReductionInit(
368
367
namespace {
369
368
template <typename AFF>
370
369
void emitAccess (AFF access, const CodegenStatementContext& context) {
371
- // Use a temporary isl::ast_build to print the expression.
372
- // Ideally, this should use the build at the point
373
- // where the user statement was created.
374
- auto astBuild = isl::ast_build::from_context (access.domain ());
375
- context.ss << astBuild.access_from (access).to_C_str ();
370
+ context.ss << context.build ().access_from (access).to_C_str ();
376
371
}
377
372
} // namespace
378
373
@@ -401,6 +396,8 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
401
396
auto stmtId = usrExp.get_op_arg (0 ).get_id ();
402
397
auto nodeId = node.get_annotation ();
403
398
auto statementContext = CodegenStatementContext (context_, nodeId);
399
+ CHECK_EQ (context_.nodeInfoMap .count (nodeId), 1 )
400
+ << " no info for node " << nodeId;
404
401
405
402
WS ws;
406
403
context_.ss << ws.tab ();
@@ -414,8 +411,6 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
414
411
emitReductionInit (stmtId, updateId, context_);
415
412
inReduction_ = true ;
416
413
} else if (inReduction_ && context_.scop ().isReductionUpdate (stmtId)) {
417
- CHECK_EQ (context_.iteratorMaps .count (nodeId), 1 )
418
- << " no iterator remapping for op " << nodeId;
419
414
emitReductionUpdate (stmtId, statementContext);
420
415
reductionUpdateNodeId_ = nodeId;
421
416
} else if (context_.scop ().isSyncId (stmtId)) {
@@ -424,14 +419,11 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
424
419
stmtId.get_name () == kReadIdName || stmtId.get_name () == kWriteIdName ) {
425
420
emitCopyStmt (statementContext);
426
421
} else { // regular statement
427
- CHECK_EQ (context_.iteratorMaps .count (nodeId), 1 )
428
- << " no iterator remapping for op " << nodeId;
429
- auto mappedStmtId =
430
- context_.iteratorMaps .at (nodeId).get_tuple_id (isl::dim_type::out);
422
+ auto mappedStmtId = statementContext.statementId ();
431
423
CHECK_EQ (stmtId, mappedStmtId)
432
424
<< " statement ids in expr (" << stmtId << " ) and in iteratorMaps ("
433
425
<< mappedStmtId << " ) do not match" ;
434
- emitUserStmt (stmtId, CodegenStatementContext (context_, nodeId) );
426
+ emitUserStmt (stmtId, statementContext );
435
427
}
436
428
}
437
429
@@ -461,11 +453,10 @@ namespace detail {
461
453
isl::pw_aff makeAffFromMappedExpr (
462
454
const Halide::Expr& expr,
463
455
const CodegenStatementContext& context) {
464
- auto space = context.iteratorMap ().get_space ().range ();
465
456
// We only expect this to be called on encountering a free
466
457
// variable. Compound expressions should be emitted as Halide.
467
458
CHECK (expr.as <Halide::Internal::Variable>());
468
- auto aff = halide2isl:: makeIslAffFromExpr (space, expr);
459
+ auto aff = context. makeIslAffFromExpr (expr);
469
460
auto pwaff = isl::pw_aff (aff).pullback (context.iteratorMap ());
470
461
return pwaff;
471
462
}
@@ -495,8 +486,7 @@ isl::multi_aff makeMultiAffAccess(
495
486
496
487
auto ma = isl::multi_aff::zero (space);
497
488
for (size_t i = 0 ; i < subscripts.size (); ++i) {
498
- ma = ma.set_aff (
499
- i, halide2isl::makeIslAffFromExpr (domainSpace, subscripts[i]));
489
+ ma = ma.set_aff (i, context.makeIslAffFromExpr (subscripts[i]));
500
490
}
501
491
return ma;
502
492
}
@@ -520,11 +510,7 @@ void emitHalideExpr(
520
510
void visit (const Halide::Internal::Variable* op) {
521
511
auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr (
522
512
Halide::Expr (op), context);
523
- // Use a temporary isl::ast_build to print the expression.
524
- // Ideally, this should use the build at the point
525
- // where the user statement was created.
526
- auto astBuild = isl::ast_build::from_context (pwAff.domain ());
527
- auto expr = astBuild.expr_from (pwAff);
513
+ auto expr = context.build ().expr_from (pwAff);
528
514
auto s = expr.to_C_str ();
529
515
if (!is_identifier_or_nonnegative_integer (expr)) {
530
516
s = " (" + s + " )" ;
@@ -724,42 +710,32 @@ string emitCudaKernel(
724
710
emitTensorViews (ss, scop.halide .inputs , paramValues);
725
711
emitTmpDecl (ss, scop);
726
712
emitPromotedArrayViewsHalide (ss, scop);
727
- IteratorMapsType iteratorMaps ;
728
- auto collect = [&iteratorMaps ](
713
+ NodeInfoMapType nodeInfoMap ;
714
+ auto collect = [&nodeInfoMap ](
729
715
isl::ast_node n, isl::ast_build b) -> isl::ast_node {
730
716
auto collectIteratorMaps =
731
717
[](isl::ast_node node,
732
718
isl::ast_build build,
733
- IteratorMapsType* iteratorMaps ) -> isl::ast_node {
719
+ NodeInfoMapType* nodeInfoMap ) -> isl::ast_node {
734
720
auto user = node.as <isl::ast_node_user>();
735
721
CHECK (user);
736
722
auto expr = user.get_expr ();
737
723
auto stmtId = expr.get_op_arg (0 ).get_id ();
738
- // We rename loop-related dimensions manually.
739
724
auto schedule = build.get_schedule ();
740
- auto scheduleSpace = build.get_schedule_space ();
741
725
auto scheduleMap = isl::map::from_union_map (schedule);
742
726
743
727
auto nodeId = isl::id (
744
728
node.get_ctx (),
745
729
std::string (kAstNodeIdPrefix ) + std::to_string (nAstNodes ()++));
746
- CHECK_EQ (0 , iteratorMaps->count (nodeId)) << " entry exists: " << nodeId;
747
- CHECK_EQ (
748
- scheduleMap.dim (isl::dim_type::out),
749
- scheduleSpace.dim (isl::dim_type::set));
750
- for (int i = 0 ; i < scheduleSpace.dim (isl::dim_type::set); ++i) {
751
- scheduleMap = scheduleMap.set_dim_id (
752
- isl::dim_type::out,
753
- i,
754
- scheduleSpace.get_dim_id (isl::dim_type::set, i));
755
- }
730
+ CHECK_EQ (0 , nodeInfoMap->count (nodeId)) << " entry exists: " << nodeId;
756
731
757
- auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
758
- iteratorMaps->emplace (nodeId, iteratorMap);
732
+ auto & nodeInfo = (*nodeInfoMap)[nodeId];
733
+ nodeInfo.iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
734
+ nodeInfo.build = build;
759
735
return node.set_annotation (nodeId);
760
736
};
761
737
762
- return collectIteratorMaps (n, b, &iteratorMaps );
738
+ return collectIteratorMaps (n, b, &nodeInfoMap );
763
739
};
764
740
765
741
auto bands = detail::ScheduleTree::collect (
@@ -781,7 +757,7 @@ string emitCudaKernel(
781
757
astBuild = astBuild.set_at_each_domain (collect);
782
758
astBuild = astBuild.set_iterators (Codegen::makeLoopIterators (ctx, maxDepth));
783
759
auto astNode = astBuild.node_from (schedule);
784
- AstPrinter (CodegenContext (ss, mscop, iteratorMaps )).emit (astNode);
760
+ AstPrinter (CodegenContext (ss, mscop, nodeInfoMap )).emit (astNode);
785
761
ss << " }" << endl;
786
762
787
763
return ss.str ();
0 commit comments