
Commit a6e0e61

Author: Samantha Andow (committed)
Allow batch norm with all variations of batching when training=False (#958)
* allow batch norm with all variations of batching when training=False
* make running mean/var always call contiguous
1 parent 0331c43 · commit a6e0e61

4 files changed: +59 −25 lines

functorch/csrc/BatchRulesNorm.cpp

+3 −9

@@ -58,7 +58,7 @@ batch_norm_batch_rule(
   auto running_mean = *running_mean_maybe_owned;
   c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
   auto running_var = *running_var_maybe_owned;
-  TORCH_CHECK(!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim)),
+  TORCH_CHECK(!training || (!input_bdim || ((!running_mean.defined() || running_mean_bdim) && (!running_var.defined() || running_var_bdim))),
       "Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, ",
       "were not batched.\nIf you are using a module and do not need eval mode, please set `track_running_stats` to be False.",
       "If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on ",
@@ -85,18 +85,12 @@ batch_norm_batch_rule(
   if (running_mean.defined()) {
     running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
     running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
-    running_mean_ = reshape_dim_into(0, 0, *running_mean_);
-    if (training) {
-      running_mean_ = running_mean_->contiguous();
-    }
+    running_mean_ = reshape_dim_into(0, 0, *running_mean_).contiguous();
   }
   if (running_var.defined()) {
     running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
     running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
-    running_var_ = reshape_dim_into(0, 0, *running_var_);
-    if (training) {
-      running_var_ = running_var_->contiguous();
-    }
+    running_var_ = reshape_dim_into(0, 0, *running_var_).contiguous();
   }

   const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
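
For illustration only (not part of the commit), here is a minimal Python sketch of the call pattern the relaxed TORCH_CHECK now admits: vmapping eval-mode batch norm over a batched input while the running statistics stay unbatched. Names and shapes are chosen for the example; it assumes functorch is installed.

# Hypothetical example of the newly allowed case: batched input, unbatched running stats, training=False.
import torch
import torch.nn.functional as F
from functorch import vmap

running_mean = torch.zeros(3)   # shared, unbatched running stats
running_var = torch.ones(3)
x = torch.randn(5, 2, 3, 8, 8)  # 5 vmapped examples, each of shape (N=2, C=3, H=8, W=8)

def bn_eval(inp):
    # training=False: the running stats are only read, never updated in place,
    # so they do not need a batch dimension of their own.
    return F.batch_norm(inp, running_mean, running_var, training=False)

out = vmap(bn_eval)(x)  # rejected by the old check, allowed after this change
print(out.shape)        # torch.Size([5, 2, 3, 8, 8])
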

test/common_utils.py

+20 −4

@@ -115,13 +115,29 @@ def get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch
                                                      batch_size=batch_size, bdims=bdims, for_batch_norm=True)


-def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
+def is_batch_norm_training(op_name, kwarg_values):
+    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
+    if op_name not in batch_norm_fns:
+        return False
+
+    # batch norm and instance norm require the value to be a plain bool
+    default_training = op_name == "nn.functional.instance_norm"  # instance norm defaults to training, batch norm doesn't
+    is_training = tuple(arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool))
+    if len(is_training) == 0:
+        return default_training
+    else:
+        assert len(is_training) == 1
+        return is_training[0]
+
+
+def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True, bdims=(0, -1)):
     out_dim = 0
     batch_size = 4
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
-    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    if opinfo is not None and opinfo.name in batch_norm_fns:
+    if is_batch_norm_and_training:
         generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)
+    else:
+        generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
+
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
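
As a quick illustration (a standalone mirror of the helper added above, not additional code from the commit), this sketch shows what flag it resolves to for a few representative op/kwarg combinations; the example op names and kwargs are chosen for demonstration.

# Standalone copy of the helper for demonstration; the real one lives in test/common_utils.py.
def is_batch_norm_training(op_name, kwarg_values):
    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")
    if op_name not in batch_norm_fns:
        return False
    default_training = op_name == "nn.functional.instance_norm"
    is_training = tuple(arg for arg in kwarg_values.values() if isinstance(arg, bool))
    if len(is_training) == 0:
        return default_training
    assert len(is_training) == 1
    return is_training[0]

print(is_batch_norm_training("nn.functional.batch_norm", {}))                  # False: batch norm defaults to eval
print(is_batch_norm_training("nn.functional.instance_norm", {}))               # True: instance norm defaults to training
print(is_batch_norm_training("nn.functional.batch_norm", {"training": True}))  # True: explicit bool kwarg wins
print(is_batch_norm_training("nn.functional.relu", {"inplace": True}))         # False: not a batch/instance norm op

The test callers below then forward the resulting flag to get_fallback_and_vmap_exhaustive via is_batch_norm_and_training instead of handing over the whole OpInfo.
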

test/test_ops.py

+25 −8

@@ -27,6 +27,7 @@
     # tol2,
     opsToleranceOverride,
     check_vmap_fallback,
+    is_batch_norm_training,
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
@@ -570,7 +571,9 @@ def vjp_of_vjp(*args_and_cotangents):
             result_vjps, _ = tree_flatten(result_vjps)
             return (*result, *result_vjps)

-        generator = get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}, opinfo=op)
+        is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+        generator = get_fallback_and_vmap_exhaustive(
+            vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training)
         for loop_out, batched_out in generator:
             self.assertEqual(loop_out, batched_out)

@@ -642,7 +645,10 @@ def test_vmapvjp(self, device, dtype, op):
         for sample in samples:
             cotangents = get_sample_cotangents(op, sample)
             fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

         # There are several variations we care about
@@ -731,7 +737,10 @@ def test_vmapjvp(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
+            is_batch_norm_and_training = is_batch_norm_training(op, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, bdims=(0,))
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     vmapjvpall_fail = {
@@ -819,7 +828,10 @@ def test_vmapjvpall(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant_primals_tangents(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -896,8 +908,9 @@ def test():
                 kwarg_values = sample.kwargs
                 args = tuple([*arg_values, *kwarg_values])
                 fn, args = get_jvp_variant_primals_tangents(op, sample)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
         check_vmap_fallback(self, test, op, dry_run=False)

@@ -1016,13 +1029,14 @@ def test():
             for sample in samples:
                 cotangents = get_sample_cotangents(op, sample)
                 fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
                 for a_op in op.aliases:
                     fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
                     for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                            fn, args, {}, opinfo=op, compute_loop_out=False):
+                            fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                         pass

         check_vmap_fallback(self, test, op, dry_run=False)
@@ -1447,7 +1461,10 @@ def was_skipped_from_batched_tensors(batched_out, batch_size):
         for sample_input in sample_inputs:
             cotangents = get_sample_cotangents(op, sample_input)
             f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                f, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
                     continue  # we weren't able to use the batched tensor in autograd.grad
                 self.assertEqual(loop_out, batched_out)

test/test_vmap.py

+11 −4

@@ -34,6 +34,7 @@
     check_vmap_fallback,
     tol1,
     opsToleranceOverride,
+    is_batch_norm_training,
 )
 import types
 from collections import namedtuple
@@ -3148,16 +3149,19 @@ def test_vmap_exhaustive(self, device, dtype, op):
         for sample_input in sample_inputs_itr:
             arg_values = [sample_input.input] + list(sample_input.args)
             kwarg_values = sample_input.kwargs
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
             try:
-                generator = get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, opinfo=op)
+                generator = get_fallback_and_vmap_exhaustive(
+                    op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                 for loop_out, batched_out in generator:
                     # empty_like and new_empty produce garbage values so we just check the shapes.
                     if op.name == 'empty_like' or op.name == 'new_empty':
                         self.assertEqual(loop_out.shape, batched_out.shape)
                         continue
                     self.assertEqual(loop_out, batched_out)
                 for a_op in op.aliases:
-                    a_generator = get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, opinfo=op)
+                    a_generator = get_fallback_and_vmap_exhaustive(
+                        a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                     for loop_out, batched_out in a_generator:
                         self.assertEqual(loop_out, batched_out)
             # todo(chilli): Garbage hack I added to deal with indexing not working
@@ -3294,15 +3298,18 @@ def test():
             for sample_input in sample_inputs_itr:
                 arg_values = [sample_input.input] + list(sample_input.args)
                 kwarg_values = sample_input.kwargs
-                generator = get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, opinfo=op)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+                generator = get_fallback_and_vmap_exhaustive(
+                    op.op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                 for loop_out, batched_out in generator:
                     # empty_like and new_empty produce garbage values so we just check the shapes.
                     if op.name == 'empty_like' or op.name == 'new_empty':
                         self.assertEqual(loop_out.shape, batched_out.shape)
                         continue
                     self.assertEqual(loop_out, batched_out)
                 for a_op in op.aliases:
-                    a_generator = get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, opinfo=op)
+                    a_generator = get_fallback_and_vmap_exhaustive(
+                        a_op, arg_values, kwarg_values, is_batch_norm_and_training=is_batch_norm_and_training)
                     for loop_out, batched_out in a_generator:
                         self.assertEqual(loop_out, batched_out)
         check_vmap_fallback(self, test, op)
