@@ -355,11 +355,11 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) {
355
355
// /////////////////////////////////////////////////////////////////////////////
356
356
// SpatialBatchNormalization
357
357
// /////////////////////////////////////////////////////////////////////////////
358
- TEST_F (TcCudaMapperTest, DISABLED_SpatialBatchNormalization ) {
358
+ TEST_F (TcCudaMapperTest, SpatialBatchNormalization ) {
359
359
N = 32 ;
360
- at::Tensor eps = at::CUDA (at::kFloat ).rand ({});
360
+ at::Tensor eps = at::CUDA (at::kFloat ).rand ({1 });
361
361
eps[0 ] = 1 .0f ;
362
- at::Tensor momentum = at::CUDA (at::kFloat ).rand ({});
362
+ at::Tensor momentum = at::CUDA (at::kFloat ).rand ({1 });
363
363
momentum[0 ] = 1.0 ;
364
364
at::Tensor I = at::CUDA (at::kFloat ).rand ({N, C2, H, W});
365
365
at::Tensor rMeanIn = at::CUDA (at::kFloat ).rand ({C2});
@@ -369,21 +369,21 @@ TEST_F(TcCudaMapperTest, DISABLED_SpatialBatchNormalization) {
369
369
370
370
static constexpr auto TC = R"TC(
371
371
def spatial_batch_norm(
372
- float momentum, float eps,
372
+ float(1) momentum, float(1) eps,
373
373
float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
374
374
-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
375
375
{
376
376
mean(c) +=! I(r_n, c, r_h, r_w)
377
377
mean(c) = mean(c) / (N * H * W)
378
- rMeanOut(c) = (1 - momentum) * rMeanIn(c) + momentum * mean(c)
378
+ rMeanOut(c) = (1 - momentum(0)) * rMeanIn(c) + momentum(0) * mean(c)
379
379
380
380
centered(n, c, h, w) = I( n, c, h, w) - rMeanOut(c)
381
381
variance(n, c, h, w) = centered( n, c, h, w) * centered(n, c, h, w)
382
- expectedVariance(c) +=! (variance(r_n, c, r_h, r_w) + eps) / (N * H * W)
382
+ expectedVariance(c) +=! (variance(r_n, c, r_h, r_w) + eps(0) ) / (N * H * W)
383
383
384
384
rVarOut(c) = rsqrt(
385
- (1 - momentum) * rVarIn(c) +
386
- momentum * expectedVariance(c))
385
+ (1 - momentum(0) ) * rVarIn(c) +
386
+ momentum(0) * expectedVariance(c))
387
387
388
388
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
389
389
normalizedOut(n, c, h, w) = O(n, c, h, w)
@@ -406,8 +406,8 @@ def spatial_batch_norm(
406
406
rMeanIn,
407
407
rVarIn,
408
408
training,
409
- at::Scalar (momentum).toFloat (),
410
- at::Scalar (eps).toFloat (),
409
+ at::Scalar (momentum[ 0 ] ).toFloat (),
410
+ at::Scalar (eps[ 0 ] ).toFloat (),
411
411
save_mean,
412
412
save_std);
413
413
auto diff = O.sub (outputs[0 ]);
0 commit comments