Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 69e85e6

Browse files
Reactivate and fix SpatialBatchNorm test
1 parent 0527be5 commit 69e85e6

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

test/cuda/test_tc_mapper.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,11 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) {
355355
///////////////////////////////////////////////////////////////////////////////
356356
// SpatialBatchNormalization
357357
///////////////////////////////////////////////////////////////////////////////
358-
TEST_F(TcCudaMapperTest, DISABLED_SpatialBatchNormalization) {
358+
TEST_F(TcCudaMapperTest, SpatialBatchNormalization) {
359359
N = 32;
360-
at::Tensor eps = at::CUDA(at::kFloat).rand({});
360+
at::Tensor eps = at::CUDA(at::kFloat).rand({1});
361361
eps[0] = 1.0f;
362-
at::Tensor momentum = at::CUDA(at::kFloat).rand({});
362+
at::Tensor momentum = at::CUDA(at::kFloat).rand({1});
363363
momentum[0] = 1.0;
364364
at::Tensor I = at::CUDA(at::kFloat).rand({N, C2, H, W});
365365
at::Tensor rMeanIn = at::CUDA(at::kFloat).rand({C2});
@@ -369,21 +369,21 @@ TEST_F(TcCudaMapperTest, DISABLED_SpatialBatchNormalization) {
369369

370370
static constexpr auto TC = R"TC(
371371
def spatial_batch_norm(
372-
float momentum, float eps,
372+
float(1) momentum, float(1) eps,
373373
float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
374374
-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
375375
{
376376
mean(c) +=! I(r_n, c, r_h, r_w)
377377
mean(c) = mean(c) / (N * H * W)
378-
rMeanOut(c) = (1 - momentum) * rMeanIn(c) + momentum * mean(c)
378+
rMeanOut(c) = (1 - momentum(0)) * rMeanIn(c) + momentum(0) * mean(c)
379379
380380
centered(n, c, h, w) = I( n, c, h, w) - rMeanOut(c)
381381
variance(n, c, h, w) = centered( n, c, h, w) * centered(n, c, h, w)
382-
expectedVariance(c) +=! (variance(r_n, c, r_h, r_w) + eps) / (N * H * W)
382+
expectedVariance(c) +=! (variance(r_n, c, r_h, r_w) + eps(0)) / (N * H * W)
383383
384384
rVarOut(c) = rsqrt(
385-
(1 - momentum) * rVarIn(c) +
386-
momentum * expectedVariance(c))
385+
(1 - momentum(0)) * rVarIn(c) +
386+
momentum(0) * expectedVariance(c))
387387
388388
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
389389
normalizedOut(n, c, h, w) = O(n, c, h, w)
@@ -406,8 +406,8 @@ def spatial_batch_norm(
406406
rMeanIn,
407407
rVarIn,
408408
training,
409-
at::Scalar(momentum).toFloat(),
410-
at::Scalar(eps).toFloat(),
409+
at::Scalar(momentum[0]).toFloat(),
410+
at::Scalar(eps[0]).toFloat(),
411411
save_mean,
412412
save_std);
413413
auto diff = O.sub(outputs[0]);

0 commit comments

Comments
 (0)