@@ -774,6 +774,65 @@ TEST(Convolution, NestedExpressions) {
   CHECK_EQ(at::Scalar(B[10]).toFloat(), 1);
 }
 
+// Previous versions of TC would map the reduction in the code below
+// to CUB, despite the fact that not every thread gets assigned
+// an instance of the reduction.
+// Check that this no longer happens.
+TEST(GroupNorm, ReductionDeadlockBug) {
+  auto group_norm = "group_norm";
+  auto TC = std::string(R"TC(
+def group_norm(
+    float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta)
+    -> (O, mean, var)
+{
+    mean(n, g) +=! I(n, g, r_d, r_h, r_w)
+    var(n, g) +=! I(n, g, r_d, r_h, r_w) * I(n, g, r_d, r_h, r_w)
+    O(n, g, d, h, w) = gamma(g, d)
+      * ( I(n, g, d, h, w) - mean(n, g) * 1.0 )
+      * rsqrt( var(n, g) * 1.0
+            - mean(n, g) * mean(n, g) * 1.0 * 1.0
+            + 1e-5)
+      + beta(g, d)
+}
+)TC");
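+  // Note on the TC above: "+=!" zero-initializes the accumulator before
+  // reducing, and the r_-prefixed indices (r_d, r_h, r_w) are the reduction
+  // dimensions, so mean and var each accumulate a per-(n, g) sum over the
+  // D x H x W elements of a group.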
+
+  uint32_t N = 4, C = 8, G = 4, D = C / G, H = 6, W = 6;
+  at::Tensor I = at::CUDA(at::kFloat).rand({N, G, D, H, W});
+  at::Tensor gamma = at::CUDA(at::kFloat).rand({G, D}).fill_(1.0f);
+  at::Tensor beta = at::CUDA(at::kFloat).rand({G, D}).fill_(0.0f);
+  std::vector<at::Tensor> inputs = {I, gamma, beta};
+  auto options = tc::CudaMappingOptions::makeNaiveMappingOptions()
+                     .outerScheduleFusionStrategy(tc::FusionStrategy::Min)
+                     .outerScheduleAllowSkewing(false)
+                     .outerSchedulePositiveOrthant(true)
+                     .intraTileScheduleFusionStrategy(tc::FusionStrategy::Min)
+                     .intraTileScheduleAllowSkewing(false)
+                     .intraTileSchedulePositiveOrthant(true)
+                     .tile(2, 6, 8, 48)
+                     .unroll(4)
+                     .tileImperfectlyNested(false)
+                     .matchLibraryCalls(true)
+                     .mapToThreads(6, 12)
+                     .mapToBlocks(8)
+                     .useSharedMemory(true)
+                     .usePrivateMemory(true)
+                     .unrollCopyShared(false);
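+  // The mapping options are pinned rather than autotuned: this particular
+  // schedule is the one that previously led to the reduction being mapped
+  // to CUB without every thread participating.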
+  auto pExecutor =
+      tc::aten::compile<tc::CudaBackend>(TC, group_norm, inputs, options);
+  auto outputs = tc::aten::prepareOutputs(TC, group_norm, inputs);
+  tc::aten::run(*pExecutor, inputs, outputs);
+  cudaDeviceSynchronize();
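+  // If the reduction were still mapped to CUB with some threads not
+  // reaching it, the kernel would deadlock and this synchronization
+  // would never return; completing it is the core of the regression test.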
+
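+  // Reference result computed with ATen: per-group mean/variance, then
+  // normalization, against which the TC output is compared.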
+  auto v = I.view({N, G, -1});
+  auto mean = v.mean(-1, true);
+  auto var = v.var(-1, true).view({N, G, 1});
+  auto x = (v - mean) / (var + 1e-5f).sqrt();
+  auto y = x.view({N, G, D, H, W});
+  cudaDeviceSynchronize();
+
+  checkRtol(outputs[0] - y, {I}, D * H * W, 1e-6);
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);