diff --git a/advertorch/attacks/fast_adaptive_boundary.py b/advertorch/attacks/fast_adaptive_boundary.py
index 54a77f2..ddd077e 100644
--- a/advertorch/attacks/fast_adaptive_boundary.py
+++ b/advertorch/attacks/fast_adaptive_boundary.py
@@ -11,8 +11,8 @@ from __future__ import unicode_literals
 
 import torch
-from torch.autograd.gradcheck import zero_gradients
 import time
+from advertorch.utils import zero_gradients
 
 try:
     from torch import flip
diff --git a/advertorch/utils.py b/advertorch/utils.py
index 1e826ed..0533b79 100644
--- a/advertorch/utils.py
+++ b/advertorch/utils.py
@@ -14,6 +14,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import collections
 
 
 def torch_allclose(x, y, rtol=1.e-5, atol=1.e-8):
@@ -389,3 +390,13 @@ def set_seed(seed=None):
         torch.manual_seed(seed)
         np.random.seed(seed)
         random.seed(seed)
+
+
+def zero_gradients(x):
+    if isinstance(x, torch.Tensor):
+        if x.grad is not None:
+            x.grad.detach_()
+            x.grad.zero_()
+    elif isinstance(x, collections.abc.Iterable):
+        for elem in x:
+            zero_gradients(elem)
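
For context: `zero_gradients` was removed from `torch.autograd.gradcheck` in newer PyTorch releases, so this patch vendors an equivalent helper into `advertorch.utils` and points the FAB attack at it. Below is a minimal, self-contained sketch of how the vendored helper behaves; the toy tensor and loss are illustrative only and not part of the patch:

    import collections.abc

    import torch


    def zero_gradients(x):
        # Recursively zero accumulated .grad buffers in place, mirroring
        # the helper added to advertorch/utils.py above.
        if isinstance(x, torch.Tensor):
            if x.grad is not None:
                x.grad.detach_()  # detach the grad buffer from the graph
                x.grad.zero_()    # reset it to zero in place
        elif isinstance(x, collections.abc.Iterable):
            for elem in x:
                zero_gradients(elem)


    x = torch.randn(3, requires_grad=True)
    (x ** 2).sum().backward()  # populates x.grad with 2 * x
    zero_gradients(x)
    assert x.grad is not None and torch.all(x.grad == 0)

Note that, unlike setting `x.grad = None`, this keeps the gradient buffer allocated and merely zeroes it, which preserves the semantics of the old PyTorch helper for callers that reuse the buffer across iterations.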