     # tol2,
     opsToleranceOverride,
     check_vmap_fallback,
+    is_batch_norm_training,
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
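
The new import pulls a small predicate out of the shared test utilities. Its body is not part of this diff; the sketch below is only an illustration of what such a check could look like, assuming (as the call sites in the hunks below suggest) that it needs nothing more than the op name and the sample's kwargs:

```python
# Illustrative sketch only; the real helper lives in the shared test
# utilities and may differ in detail.
def is_batch_norm_training(op_name, kwarg_values):
    # Only batch_norm (and instance_norm, which lowers to batch_norm) can be
    # "batch norm running in training mode"; everything else is a quick False.
    if op_name not in ("nn.functional.batch_norm", "nn.functional.instance_norm"):
        return False
    # In the OpInfo samples the training/use_input_stats switch is the only
    # plain-bool kwarg, so treat any bool found there as that flag.
    flags = [v for v in kwarg_values.values() if isinstance(v, bool)]
    if flags:
        return flags[0]
    # Defaults when the flag is omitted: instance_norm uses input
    # (training-style) statistics, batch_norm does not.
    return op_name == "nn.functional.instance_norm"
```

Each test below computes this flag once per sample and forwards it to get_fallback_and_vmap_exhaustive instead of handing over the whole OpInfo.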
@@ -570,7 +571,9 @@ def vjp_of_vjp(*args_and_cotangents):
                 result_vjps, _ = tree_flatten(result_vjps)
                 return (*result, *result_vjps)

-            generator = get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}, opinfo=op)
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training)
             for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

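For context, the loop_out/batched_out pairs these tests compare come from running the op once per example in a plain Python loop and once under vmap with the batch dimension in various positions. Batch norm in training mode is awkward here because it computes statistics across the batch and updates running_mean/running_var in place, which does not interact with an extra vmap dimension the way an ordinary per-sample op does; passing the precomputed flag lets the exhaustive helper special-case those samples without needing the whole OpInfo. A minimal, self-contained sketch of the comparison idea, simplified to a single batch dimension at position 0 (the real helper sweeps many in_dims/out_dims combinations):

```python
import torch
from functorch import vmap

def loop_reference(fn, batched_args):
    # Reference result: apply fn to one slice at a time and stack the outputs.
    batch_size = batched_args[0].shape[0]
    return torch.stack([fn(*(a[i] for a in batched_args)) for i in range(batch_size)])

def check_against_vmap(fn, *batched_args):
    loop_out = loop_reference(fn, batched_args)   # the "loop_out" side
    batched_out = vmap(fn)(*batched_args)         # the "batched_out" side
    torch.testing.assert_close(loop_out, batched_out)

check_against_vmap(torch.sin, torch.randn(3, 5))
check_against_vmap(torch.add, torch.randn(3, 5), torch.randn(3, 5))
```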
@@ -642,7 +645,10 @@ def test_vmapvjp(self, device, dtype, op):
         for sample in samples:
             cotangents = get_sample_cotangents(op, sample)
             fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     # There are several variations we care about
@@ -731,7 +737,10 @@ def test_vmapjvp(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
+            is_batch_norm_and_training = is_batch_norm_training(op, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, bdims=(0,))
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     vmapjvpall_fail = {
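test_vmapjvp builds a jvp variant of the op from the sample and runs the same sweep, here restricted by bdims=(0,), which appears to limit the helper to inserting the batch dimension at position 0. Note this call site passes op rather than op.name to is_batch_norm_training, unlike the other hunks. A rough sketch of the composition being exercised (vmap over a jvp computation), using functorch.jvp directly instead of the test's get_jvp_variant wrapper:

```python
# Sketch of vmap composed over a jvp computation; the real test wraps the
# OpInfo sample via get_jvp_variant, this just uses a toy function.
import torch
from functorch import vmap, jvp

def f(x):
    return x.sin().sum()

def jvp_of_f(primal, tangent):
    _, out_tangent = jvp(f, (primal,), (tangent,))
    return out_tangent

primals = torch.randn(3, 5)   # a batch of 3 primals
tangents = torch.randn(3, 5)  # a matching batch of tangents
batched_out = vmap(jvp_of_f)(primals, tangents)  # batch dim at position 0
loop_out = torch.stack([jvp_of_f(primals[i], tangents[i]) for i in range(3)])
torch.testing.assert_close(batched_out, loop_out)
```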
@@ -819,7 +828,10 @@ def test_vmapjvpall(self, device, dtype, op):
             kwarg_values = sample.kwargs
             args = tuple([*arg_values, *kwarg_values])
             fn, args = get_jvp_variant_primals_tangents(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
+            generator = get_fallback_and_vmap_exhaustive(
+                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                 self.assertEqual(loop_out, batched_out)

     @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -896,8 +908,9 @@ def test():
                 kwarg_values = sample.kwargs
                 args = tuple([*arg_values, *kwarg_values])
                 fn, args = get_jvp_variant_primals_tangents(op, sample)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
         check_vmap_fallback(self, test, op, dry_run=False)

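The *_has_batch_rule tests use the sweep differently: compute_loop_out=False skips the looped reference entirely and the loop body is just pass, so the sweep only checks that every batched configuration runs, and check_vmap_fallback then decides per op whether hitting the slow vmap fallback counts as a failure. A hypothetical, simplified generator illustrating what compute_loop_out=False buys (the name and structure below are illustrative, not the real implementation):

```python
import torch
from functorch import vmap

def sweep_sketch(fn, batched_args, compute_loop_out=True):
    # Always exercise the batched path.
    batched_out = vmap(fn)(*batched_args)
    if compute_loop_out:
        batch_size = batched_args[0].shape[0]
        loop_out = torch.stack([fn(*(a[i] for a in batched_args)) for i in range(batch_size)])
    else:
        loop_out = None  # nothing to compare against; caller only exercises the op
    yield loop_out, batched_out

for loop_out, batched_out in sweep_sketch(torch.cos, (torch.randn(4, 3),), compute_loop_out=False):
    pass  # mirrors the test body: running the batched path is the whole point
```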
@@ -1016,13 +1029,14 @@ def test():
             for sample in samples:
                 cotangents = get_sample_cotangents(op, sample)
                 fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
+                is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
                 for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                        fn, args, {}, opinfo=op, compute_loop_out=False):
+                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                     pass
                 for a_op in op.aliases:
                     fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
                     for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
-                            fn, args, {}, opinfo=op, compute_loop_out=False):
+                            fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                         pass

         check_vmap_fallback(self, test, op, dry_run=False)
@@ -1447,7 +1461,10 @@ def was_skipped_from_batched_tensors(batched_out, batch_size):
         for sample_input in sample_inputs:
            cotangents = get_sample_cotangents(op, sample_input)
            f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}, opinfo=op):
+            is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
+            generator = get_fallback_and_vmap_exhaustive(
+                f, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
+            for loop_out, batched_out in generator:
                if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
                    continue  # we weren't able to use the batched tensor in autograd.grad
                self.assertEqual(loop_out, batched_out)
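
The last hunk is the autograd.grad variant of the same change. The surrounding test filters results through was_skipped_from_batched_tensors before comparing, since autograd.grad can decline to use the batched tensor (per the inline comment). The predicate's body sits outside this diff; below is a loose sketch of the kind of check its call site implies, with the concrete condition being an assumption made only for illustration:

```python
# Assumed sketch: given one batched output and the expected batch size, return
# True when the batched run did not really produce a per-sample result, so
# comparing it against loop_out would be meaningless. The exact condition in
# the real test file may differ.
def was_skipped_from_batched_tensors(batched_out, batch_size):
    if batched_out is None:
        return True  # autograd.grad produced no gradient for this output
    # A genuine per-sample result should carry the batch dimension up front.
    return batched_out.shape[:1] != (batch_size,)
```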