Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6e93aab

Browse files
author
Sven Verdoolaege
committed
tc2halide: properly tag reduction updates
In particular, only tag reduction updates and only tag them once. The original code would tag any non-initial write to a tensor involved in a reduction and it would tag it possibly several times, once for each reduction on that tensor. Tag the reduction updates immediately to ensure that only reduction updates get tagged and that they get tagged exactly once. Note that the tags are now introduced in the function definitions rather than only in the lowered statements. The presence of these intrinsics may hinder some of the analysis passes in Halide's standard lowering path. If these passer ever need to be used, then the intrinsic introduction mechanism may need to be revisited. Multiply tagged expressions cause confusion in TC. In particular, isSupportedReduction would fail to detect the reduction and the code generation also assumes that a reduction update is only tagged once. Tests by Mathieu Fehr.
1 parent d67e305 commit 6e93aab

File tree

2 files changed

+78
-31
lines changed

2 files changed

+78
-31
lines changed

tc/core/tc2halide.cc

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -488,13 +488,15 @@ Expr reductionUpdate(Expr e) {
488488
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
489489
}
490490

491+
// Note that the function definitions created by translateComprehension may
492+
// contain kReductionUpdate intrinsics. These may have to be removed
493+
// in order to be able to apply internal Halide analysis passes on them.
491494
void translateComprehension(
492495
const lang::Comprehension& c,
493496
const map<string, Parameter>& params,
494497
bool throwWarnings,
495498
map<string, Function>* funcs,
496-
FunctionBounds* bounds,
497-
vector<Function>* reductions) {
499+
FunctionBounds* bounds) {
498500
Function f;
499501
auto it = funcs->find(c.ident().name());
500502
if (it != funcs->end()) {
@@ -589,8 +591,9 @@ void translateComprehension(
589591
<< c.assignment()->range().text() << "\n";
590592
}
591593

594+
// Tag reductions as such
592595
if (c.assignment()->kind() != '=') {
593-
reductions->push_back(f);
596+
rhs = reductionUpdate(rhs);
594597
}
595598

596599
// Bind any scalar params on the rhs to their parameter objects.
@@ -739,13 +742,12 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
739742
components.def = def;
740743
FunctionBounds bounds;
741744

742-
vector<Function> reductions;
743745
for (auto p : def.params()) {
744746
translateParam(p, &components.params, &components.inputs);
745747
}
746748
for (auto c : def.statements()) {
747749
translateComprehension(
748-
c, components.params, throwWarnings, &funcs, &bounds, &reductions);
750+
c, components.params, throwWarnings, &funcs, &bounds);
749751
}
750752
vector<Function> outputs;
751753
for (auto p : def.returns()) {
@@ -803,32 +805,6 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
803805
s = uniquify_variable_names(s);
804806
s = simplify(s);
805807

806-
// Tag reductions as such
807-
for (const Function& f : reductions) {
808-
class TagReduction : public IRMutator2 {
809-
using IRMutator2::visit;
810-
bool found_init = false;
811-
Stmt visit(const Provide* op) override {
812-
if (op->name == f.name()) {
813-
if (found_init) {
814-
return Provide::make(
815-
op->name, {reductionUpdate(op->values[0])}, op->args);
816-
} else {
817-
found_init = true;
818-
return op;
819-
}
820-
} else {
821-
return op;
822-
}
823-
}
824-
const Function& f;
825-
826-
public:
827-
TagReduction(const Function& f) : f(f) {}
828-
} tagReduction(f);
829-
s = tagReduction.mutate(s);
830-
}
831-
832808
// Trim ProducerConsumer annotations. TC doesn't use them.
833809
class RemoveProducerConsumer : public IRMutator2 {
834810
using IRMutator2::visit;

test/test_cuda_mapper.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,77 @@ def fun(float(N) I) -> (O) {
861861
EXPECT_TRUE(code.find("O[0] = (O") != std::string::npos);
862862
}
863863

864+
struct ReductionTest : public PolyhedralMapperTest {
865+
static CudaMappingOptions reductionTestMappingOptions() {
866+
return DefaultOptions()
867+
.outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
868+
.outerScheduleAllowSkewing(false)
869+
.outerSchedulePositiveOrthant(true)
870+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
871+
.intraTileScheduleAllowSkewing(false)
872+
.intraTileSchedulePositiveOrthant(true)
873+
.fixParametersBeforeScheduling(false)
874+
.tile(18, 32)
875+
.unroll(16)
876+
.tileImperfectlyNested(false)
877+
.matchLibraryCalls(true)
878+
.mapToThreads({512})
879+
.mapToBlocks({16384})
880+
.useSharedMemory(true)
881+
.usePrivateMemory(false)
882+
.unrollCopyShared(true);
883+
}
884+
885+
void Check(const string& tc) {
886+
auto code = codegenMapped(tc, reductionTestMappingOptions());
887+
using tc::code::cuda::kCUBReductionName;
888+
EXPECT_TRUE(code.find(kCUBReductionName) != std::string::npos);
889+
}
890+
};
891+
892+
/*
893+
* Check that a reduction library call is produced when the reduction
894+
* instruction is before an instruction modifying the same tensor.
895+
*/
896+
TEST_F(ReductionTest, BeforeInstruction) {
897+
Check(R"TC(
898+
def fun(float(N, K) I) -> (O) {
899+
O(n) +=! I(n, r_n)
900+
O(n) = O(n) / (K)
901+
}
902+
)TC");
903+
}
904+
905+
/*
906+
* Check that a reduction library call is produced when the reduction
907+
* instruction is after an instruction modifying the same tensor.
908+
*/
909+
TEST_F(ReductionTest, AfterInstruction) {
910+
Check(R"TC(
911+
def fun(float(N, K) I, float(N) O0) -> (O) {
912+
O(n) = 0.0 where n in 0:N
913+
O(n) += O0(n)
914+
O(n) += I(n, r_n)
915+
}
916+
)TC");
917+
}
918+
919+
/*
920+
* Check that a reduction library call is produced when the reduction
921+
* instruction is placed after an instruction modifying the same tensor and
922+
* before an instruction modifying the same tensor.
923+
*/
924+
TEST_F(ReductionTest, BetweenInstructions) {
925+
Check(R"TC(
926+
def fun(float(N, K) I, float(N) O0) -> (O) {
927+
O(n) = 0.0 where n in 0:N
928+
O(n) += O0(n)
929+
O(n) += I(n, r_n)
930+
O(n) = O(n) / (K)
931+
}
932+
)TC");
933+
}
934+
864935
static const string kTcMM = R"TC(
865936
def fun(float(M, K) A, float(K, N) B) -> (C) {
866937
C(m, n) +=! A(m, r_k) * B(r_k, n)

0 commit comments

Comments
 (0)