@@ -359,28 +359,33 @@ void emitReductionInit(
359
359
context.ss << " ;" << endl;
360
360
}
361
361
362
- void emitCopyStmt (const CodegenStatementContext& context) {
363
- using detail::emitDirectSubscripts;
362
+ namespace {
363
+ template <typename AFF>
364
+ void emitAccess (AFF access, const CodegenStatementContext& context) {
365
+ // Use a temporary isl::ast_build to print the expression.
366
+ // Ideally, this should use the build at the point
367
+ // where the user statement was created.
368
+ auto astBuild = isl::ast_build::from_context (access.domain ());
369
+ context.ss << astBuild.access_from (access).to_C_str ();
370
+ }
371
+ } // namespace
364
372
373
+ void emitCopyStmt (const CodegenStatementContext& context) {
365
374
auto stmtId = context.statementId ();
366
375
367
376
auto iteratorMap = context.iteratorMap ();
368
377
auto promoted = iteratorMap.range_factor_range ();
369
378
auto original = iteratorMap.range_factor_domain ().range_factor_range ();
370
379
auto isRead = stmtId.get_name () == kReadIdName ;
371
- auto originalName = original.get_tuple_id (isl::dim_type::out).get_name ();
372
- auto promotedName = promoted.get_tuple_id (isl::dim_type::out).get_name ();
373
380
374
381
if (isRead) {
375
- context.ss << promotedName;
376
- emitDirectSubscripts (promoted, context);
377
- context.ss << " = " << originalName;
378
- emitDirectSubscripts (original, context);
382
+ emitAccess (isl::multi_pw_aff (promoted), context);
383
+ context.ss << " = " ;
384
+ emitAccess (isl::multi_pw_aff (original), context);
379
385
} else {
380
- context.ss << originalName;
381
- emitDirectSubscripts (original, context);
382
- context.ss << " = " << promotedName;
383
- emitDirectSubscripts (promoted, context);
386
+ emitAccess (isl::multi_pw_aff (original), context);
387
+ context.ss << " = " ;
388
+ emitAccess (isl::multi_pw_aff (promoted), context);
384
389
}
385
390
context.ss << " ;" << std::endl;
386
391
}
@@ -447,14 +452,6 @@ void AstPrinter::emitAst(isl::ast_node node) {
447
452
448
453
namespace detail {
449
454
450
- std::string toString (isl::pw_aff subscript) {
451
- // Use a temporary isl::ast_build to print the expression.
452
- // Ideally, this should use the build at the point
453
- // where the user statement was created.
454
- auto astBuild = isl::ast_build::from_context (subscript.domain ());
455
- return astBuild.expr_from (subscript).to_C_str ();
456
- }
457
-
458
455
isl::pw_aff makeAffFromMappedExpr (
459
456
const Halide::Expr& expr,
460
457
const CodegenStatementContext& context) {
@@ -498,18 +495,35 @@ isl::multi_aff makeMultiAffAccess(
498
495
return ma;
499
496
}
500
497
498
+ namespace {
499
+ bool is_identifier_or_nonnegative_integer (isl::ast_expr expr) {
500
+ if (isl_ast_expr_get_type (expr.get ()) == isl_ast_expr_id)
501
+ return true ;
502
+ if (isl_ast_expr_get_type (expr.get ()) != isl_ast_expr_int)
503
+ return false ;
504
+ return isl::manage (isl_ast_expr_get_val (expr.get ())).is_nonneg ();
505
+ }
506
+ } // namespace
507
+
501
508
void emitHalideExpr (
502
509
const Halide::Expr& e,
503
510
const CodegenStatementContext& context,
504
511
const map<string, string>& substitutions) {
505
512
class EmitHalide : public Halide ::Internal::IRPrinter {
506
513
using Halide::Internal::IRPrinter::visit;
507
514
void visit (const Halide::Internal::Variable* op) {
508
- // This is probably needlessly indirect, given that we just have
509
- // a name to look up somewhere.
510
515
auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr (
511
516
Halide::Expr (op), context);
512
- context.ss << tc::polyhedral::detail::toString (pwAff);
517
+ // Use a temporary isl::ast_build to print the expression.
518
+ // Ideally, this should use the build at the point
519
+ // where the user statement was created.
520
+ auto astBuild = isl::ast_build::from_context (pwAff.domain ());
521
+ auto expr = astBuild.expr_from (pwAff);
522
+ auto s = expr.to_C_str ();
523
+ if (!is_identifier_or_nonnegative_integer (expr)) {
524
+ s = " (" + s + " )" ;
525
+ }
526
+ context.ss << s;
513
527
}
514
528
void visit (const Halide::Internal::Call* op) {
515
529
if (substitutions.count (op->name )) {
@@ -613,19 +627,7 @@ void emitMappedTensorAccess(
613
627
auto astToPromoted =
614
628
isl::pw_multi_aff (promotion).pullback (astToScheduledOriginal);
615
629
616
- auto astBuild = isl::ast_build::from_context (astToPromoted.domain ());
617
- context.ss << astBuild.access_from (astToPromoted).to_C_str ();
618
- }
619
-
620
- void emitDirectSubscripts (
621
- isl::pw_multi_aff subscripts,
622
- const CodegenStatementContext& context) {
623
- auto mpa = isl::multi_pw_aff (subscripts); // this conversion is safe
624
- for (auto pa : isl::MPA (mpa)) {
625
- context.ss << " [" ;
626
- context.ss << toString (pa.pa );
627
- context.ss << " ]" ;
628
- }
630
+ emitAccess (astToPromoted, context);
629
631
}
630
632
631
633
} // namespace detail
0 commit comments