Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 4fa54d5

Browse files
nicolasvasilache, Sven Verdoolaege
authored and committed
partialTargetTiles: ensure that tiling schedule applies to entire domain
The tiling schedule should be total on the domain. Otherwise, it may drop some domain elements from consideration, skewing the result of the function. In the extreme case, the domain of the schedule is disjoint from the domain, resulting in the perceived absence of any tiles and therefore also of partial tiles. This would cause the new test case to fail. The problem is that in infixScheduleMupa, the combination of a zero-dimensional isl::multi_union_pw_aff with explicit domain (the initial value) and a zero-dimensional isl::multi_union_pw_aff without explicit domain (the one from a band inserted by ScheduleTree::makeEmptyBand) would result in a zero-dimensional isl::multi_union_pw_aff with an _empty_ domain. Arguably, this is a bug in isl. Work around this issue by setting an explicit domain on the partial schedule of the band created by ScheduleTree::makeEmptyBand. This explicit domain is also more in line with the rest of the code base and allows the partial schedule to be converted to an isl::union_map directly, if needed.
1 parent 65faedc commit 4fa54d5

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

tc/core/polyhedral/schedule_tree.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ ScheduleTreeUPtr ScheduleTree::makeEmptyBand(const ScheduleTree* root) {
212212
auto domain = root->elemAs<ScheduleTreeElemDomain>();
213213
CHECK(domain);
214214
auto space = domain->domain_.get_space().set_from_params();
215-
auto zero = isl::multi_union_pw_aff::zero(space);
215+
auto mv = isl::multi_val::zero(space);
216+
auto zero = isl::multi_union_pw_aff(domain->domain_, mv);
216217
return ScheduleTree::makeBand(zero);
217218
}
218219

tc/core/polyhedral/separation.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ isl::union_set partialTargetTiles(
3939
// Mapping between prefix values and target values
4040
// for some common domain element
4141
// P -> T
42+
CHECK(domain.is_subset(scheduleMap.domain()));
4243
auto target = domain.apply(scheduleMap).unwrap();
4344
// Mapping between prefix values and target values
4445
// for some common domain element, extended to complete target tiles.

test/cuda/test_tc_mapper_bugs.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,65 @@ TEST(Convolution, NestedExpressions) {
774774
CHECK_EQ(at::Scalar(B[10]).toFloat(), 1);
775775
}
776776

777+
// Previous versions of TC would map the reduction in the code below
778+
// to CUB, despite the fact that not every thread gets assigned
779+
// an instance of the reduction.
780+
// Check that this no longer happens.
781+
TEST(GroupNorm, ReductionDeadlockBug) {
782+
auto group_norm = "group_norm";
783+
auto TC = std::string(R"TC(
784+
def group_norm(
785+
float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta)
786+
-> (O, mean, var)
787+
{
788+
mean(n, g) +=! I(n, g, r_d, r_h, r_w)
789+
var(n, g) +=! I(n, g, r_d, r_h, r_w) * I(n, g, r_d, r_h, r_w)
790+
O(n, g, d, h, w) = gamma(g, d)
791+
* ( I(n, g, d, h, w) - mean(n, g) * 1.0 )
792+
* rsqrt( var(n, g) * 1.0
793+
- mean(n, g) * mean(n, g) * 1.0 * 1.0
794+
+ 1e-5)
795+
+ beta(g, d)
796+
}
797+
)TC");
798+
799+
uint32_t N = 4, C = 8, G = 4, D = C / G, H = 6, W = 6;
800+
at::Tensor I = at::CUDA(at::kFloat).rand({N, G, D, H, W});
801+
at::Tensor gamma = at::CUDA(at::kFloat).rand({G, D}).fill_(1.0f);
802+
at::Tensor beta = at::CUDA(at::kFloat).rand({G, D}).fill_(0.0f);
803+
std::vector<at::Tensor> inputs = {I, gamma, beta};
804+
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions()
805+
.outerScheduleFusionStrategy(tc::FusionStrategy::Min)
806+
.outerScheduleAllowSkewing(false)
807+
.outerSchedulePositiveOrthant(true)
808+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
809+
.intraTileScheduleAllowSkewing(false)
810+
.intraTileSchedulePositiveOrthant(true)
811+
.tile(2, 6, 8, 48)
812+
.unroll(4)
813+
.tileImperfectlyNested(false)
814+
.matchLibraryCalls(true)
815+
.mapToThreads(6, 12)
816+
.mapToBlocks(8)
817+
.useSharedMemory(true)
818+
.usePrivateMemory(true)
819+
.unrollCopyShared(false);
820+
auto pExecutor =
821+
tc::aten::compile<tc::CudaBackend>(TC, group_norm, inputs, options);
822+
auto outputs = tc::aten::prepareOutputs(TC, group_norm, inputs);
823+
tc::aten::run(*pExecutor, inputs, outputs);
824+
cudaDeviceSynchronize();
825+
826+
auto v = I.view({N, G, -1});
827+
auto mean = v.mean(-1, true);
828+
auto var = v.var(-1, true).view({N, G, 1});
829+
auto x = (v - mean) / (var + 1e-5f).sqrt();
830+
auto y = x.view({N, G, D, H, W});
831+
cudaDeviceSynchronize();
832+
833+
checkRtol(outputs[0] - y, {I}, D * H * W, 1e-6);
834+
}
835+
777836
int main(int argc, char** argv) {
778837
::testing::InitGoogleTest(&argc, argv);
779838
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)