Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6c8a77e

Browse files
Merge pull request #417 from facebookresearch/pr/reduction-deadlock
partialTargetTiles: ensure that tiling schedule applies to entire domain
2 parents 9cc088f + 4fa54d5 commit 6c8a77e

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

tc/core/polyhedral/schedule_tree.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ ScheduleTreeUPtr ScheduleTree::makeEmptyBand(const ScheduleTree* root) {
212212
auto domain = root->elemAs<ScheduleTreeElemDomain>();
213213
CHECK(domain);
214214
auto space = domain->domain_.get_space().set_from_params();
215-
auto zero = isl::multi_union_pw_aff::zero(space);
215+
auto mv = isl::multi_val::zero(space);
216+
auto zero = isl::multi_union_pw_aff(domain->domain_, mv);
216217
return ScheduleTree::makeBand(zero);
217218
}
218219

tc/core/polyhedral/separation.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ isl::union_set partialTargetTiles(
3939
// Mapping between prefix values and target values
4040
// for some common domain element
4141
// P -> T
42+
CHECK(domain.is_subset(scheduleMap.domain()));
4243
auto target = domain.apply(scheduleMap).unwrap();
4344
// Mapping between prefix values and target values
4445
// for some common domain element, extended to complete target tiles.

test/cuda/test_tc_mapper_bugs.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,65 @@ TEST(Convolution, NestedExpressions) {
774774
CHECK_EQ(at::Scalar(B[10]).toFloat(), 1);
775775
}
776776

777+
// Previous versions of TC would map the reduction in the code below
778+
// to CUB, despite the fact that not every thread gets assigned
779+
// an instance of the reduction.
780+
// Check that this no longer happens.
781+
// Regression test: compiles the group_norm TC below with a fixed set of CUDA
// mapping options, runs it once on random input, and compares the first
// output against a reference computed directly with ATen tensor operations.
TEST(GroupNorm, ReductionDeadlockBug) {
782+
auto group_norm = "group_norm";
783+
auto TC = std::string(R"TC(
784+
def group_norm(
785+
float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta)
786+
-> (O, mean, var)
787+
{
788+
mean(n, g) +=! I(n, g, r_d, r_h, r_w)
789+
var(n, g) +=! I(n, g, r_d, r_h, r_w) * I(n, g, r_d, r_h, r_w)
790+
O(n, g, d, h, w) = gamma(g, d)
791+
* ( I(n, g, d, h, w) - mean(n, g) * 1.0 )
792+
* rsqrt( var(n, g) * 1.0
793+
- mean(n, g) * mean(n, g) * 1.0 * 1.0
794+
+ 1e-5)
795+
+ beta(g, d)
796+
}
797+
)TC");
798+
799+
// Problem sizes; C is only used to derive D = C / G.
uint32_t N = 4, C = 8, G = 4, D = C / G, H = 6, W = 6;
800+
at::Tensor I = at::CUDA(at::kFloat).rand({N, G, D, H, W});
801+
// gamma is overwritten with 1 and beta with 0 by fill_(); the reference
// computation further down does not apply gamma or beta.
at::Tensor gamma = at::CUDA(at::kFloat).rand({G, D}).fill_(1.0f);
802+
at::Tensor beta = at::CUDA(at::kFloat).rand({G, D}).fill_(0.0f);
803+
std::vector<at::Tensor> inputs = {I, gamma, beta};
804+
// Mapping options for this test: library-call matching enabled with a
// 6x12 thread mapping and 8 blocks.
// NOTE(review): presumably matchLibraryCalls(true) is what allows the
// reduction to be mapped to CUB as described in the comment preceding this
// test -- confirm against the mapper.
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions()
805+
.outerScheduleFusionStrategy(tc::FusionStrategy::Min)
806+
.outerScheduleAllowSkewing(false)
807+
.outerSchedulePositiveOrthant(true)
808+
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
809+
.intraTileScheduleAllowSkewing(false)
810+
.intraTileSchedulePositiveOrthant(true)
811+
.tile(2, 6, 8, 48)
812+
.unroll(4)
813+
.tileImperfectlyNested(false)
814+
.matchLibraryCalls(true)
815+
.mapToThreads(6, 12)
816+
.mapToBlocks(8)
817+
.useSharedMemory(true)
818+
.usePrivateMemory(true)
819+
.unrollCopyShared(false);
820+
// Compile for the CUDA backend, allocate the outputs, and run once.
auto pExecutor =
821+
tc::aten::compile<tc::CudaBackend>(TC, group_norm, inputs, options);
822+
auto outputs = tc::aten::prepareOutputs(TC, group_norm, inputs);
823+
tc::aten::run(*pExecutor, inputs, outputs);
824+
cudaDeviceSynchronize();
825+
826+
// Reference: collapse (D, H, W) into a single trailing axis and normalize
// over that axis for each (n, g) pair.
// NOTE(review): the meaning of the second argument to var() depends on the
// ATen signature in use; the explicit view gives var the shape {N, G, 1}
// regardless -- confirm.
// NOTE(review): the TC accumulates plain sums into mean/var and multiplies
// them by 1.0, whereas this reference uses v.mean()/v.var() -- confirm the
// two are expected to agree within the checkRtol tolerance.
auto v = I.view({N, G, -1});
827+
auto mean = v.mean(-1, true);
828+
auto var = v.var(-1, true).view({N, G, 1});
829+
auto x = (v - mean) / (var + 1e-5f).sqrt();
830+
auto y = x.view({N, G, D, H, W});
831+
cudaDeviceSynchronize();
832+
833+
// Compare the first TC output (O) with the reference.
checkRtol(outputs[0] - y, {I}, D * H * W, 1e-6);
834+
}
835+
777836
int main(int argc, char** argv) {
778837
::testing::InitGoogleTest(&argc, argv);
779838
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)