@@ -396,10 +396,49 @@ struct LdgWrapper {
  std::ostream& out_;
};

+ template <typename AFF>
+ isl::ast_expr buildAccess(AFF access, const CodegenStatementContext& context) {
+   return context.build().access_from(access);
+ }
+
+ void emitAccess(isl::ast_expr access, const CodegenStatementContext& context) {
+   context.ss << access.to_C_str();
+ }
+
template <typename AFF>
void emitAccess(AFF access, const CodegenStatementContext& context) {
+   emitAccess(buildAccess(access, context), context);
+ }
+
+ // Check that the given expression is an access with constant index expressions
+ void checkConstantAccess(isl::ast_expr expr) {
+   auto op = expr.as<isl::ast_expr_op>();
+   auto access = op.as<isl::ast_op_access>();
+   TC_CHECK(access);
+   for (int i = 1; i < access.get_n_arg(); ++i) {
+     auto arg = access.get_arg(i);
+     TC_CHECK(arg.as<isl::ast_expr_int>())
+         << "expected constant subscript, got " << arg.to_C_str();
+   }
+ }
+
+ // Print an access to a(n array of) register(s), checking that
+ // the index expressions are constant.
+ void emitRegisterAccess(
+     isl::pw_multi_aff access,
+     const CodegenStatementContext& context) {
+   auto expr = buildAccess(access, context);
+   checkConstantAccess(expr);
+   emitAccess(expr, context);
+ }
+
+ // Print an access to global memory, wrapping the access in an "__ldg()"
+ // call if the accessed tensor is known to be read-only.
+ void emitGlobalAccess(
+     isl::multi_pw_aff access,
+     const CodegenStatementContext& context) {
  LdgWrapper ldgWrapper(context, access.get_tuple_id(isl::dim_type::out));
-   context.ss << context.build().access_from(access).to_C_str();
+   emitAccess(access, context);
}

} // namespace
@@ -414,9 +453,9 @@ void emitCopyStmt(const CodegenStatementContext& context) {
  if (isRead) {
    emitAccess(isl::multi_pw_aff(promoted), context);
    context.ss << " = ";
-     emitAccess(isl::multi_pw_aff(original), context);
+     emitGlobalAccess(isl::multi_pw_aff(original), context);
  } else {
-     emitAccess(isl::multi_pw_aff(original), context);
+     emitGlobalAccess(isl::multi_pw_aff(original), context);
    context.ss << " = ";
    emitAccess(isl::multi_pw_aff(promoted), context);
  }
@@ -625,7 +664,8 @@ void emitMappedTensorAccess(
    return;
  }

-   auto tensorId = context.scop().promotedDecl(promotionInfo.groupId).tensorId;
+   auto decl = context.scop().promotedDecl(promotionInfo.groupId);
+   auto tensorId = decl.tensorId;

  // Here and below in comments: D = domain, O = original tensor, P = promoted
  // tensor, S = partial schedule, A = AST loops;
@@ -651,7 +691,11 @@ void emitMappedTensorAccess(
  auto astToPromoted =
      isl::pw_multi_aff(promotion).pullback(astToScheduledOriginal);

-   emitAccess(astToPromoted, context);
+   if (decl.kind == Scop::PromotedDecl::Kind::Register) {
+     emitRegisterAccess(astToPromoted, context);
+   } else {
+     emitAccess(astToPromoted, context);
+   }
}

} // namespace detail
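
For illustration only (not part of this change, and with hypothetical tensor, promoted-array, and iterator names), the two emitter paths above correspond roughly to the following kinds of generated CUDA:

// emitGlobalAccess: a load from a read-only global tensor gets wrapped
// in __ldg() by LdgWrapper; the subscripts may involve loop iterators.
_A_promoted[0][1] = __ldg(&A[16 * b0 + t0][16 * b1 + t1 + 1]);

// emitRegisterAccess: the promoted array is meant to stay in registers,
// so every subscript must be an integer constant; checkConstantAccess()
// would reject something like _A_promoted[t0][1].
C[16 * b0 + t0][16 * b1 + t1] = _A_promoted[0][1];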