Commit 711c5de

Update sgdw for older pytorch

1 parent 60b170b

1 file changed (+23 -12 lines)


timm/optim/sgdw.py

@@ -1,6 +1,13 @@
+from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
+
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
@@ -62,7 +69,9 @@ def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
 
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    # without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -124,17 +133,19 @@ def sgdw(
 
     See :class:`~torch.optim.SGD` for details.
     """
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
+        if foreach is None:
+            # why must we be explicit about an if statement for torch.jit.is_scripting here?
+            # because JIT can't handle Optionals nor fancy conditionals when scripting
+            if not torch.jit.is_scripting():
+                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
+            else:
+                foreach = False
 
-    if foreach is None:
-        # why must we be explicit about an if statement for torch.jit.is_scripting here?
-        # because JIT can't handle Optionals nor fancy conditionals when scripting
-        if not torch.jit.is_scripting():
-            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
-        else:
-            foreach = False
-
-    if foreach and torch.jit.is_scripting():
-        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+        if foreach and torch.jit.is_scripting():
+            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
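
The compatibility pattern in this commit has two parts: a guarded import that records whether the private helpers exist (has_recent_pt), and a runtime hasattr probe on Optimizer before trusting the foreach heuristic. A minimal sketch of that pattern, assuming only what the diff shows; resolve_foreach is a hypothetical helper name, since the commit actually inlines this logic in sgdw():

import torch
from torch.optim.optimizer import Optimizer

try:
    # present in recent PyTorch; private API, so it may move between releases
    from torch.optim.optimizer import _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False

def resolve_foreach(params, foreach=None):
    # hypothetical helper: decide whether the multi-tensor (foreach)
    # code path can be used on this PyTorch build
    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
        if foreach is None and not torch.jit.is_scripting():
            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
        return bool(foreach)
    return False  # older PyTorch: fall back to the single-tensor path

Gating on both the import and the hasattr check means a PyTorch build that exports the helper but lacks _group_tensors_by_device_and_dtype still falls back safely to the single-tensor path.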

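Because step() is now decorated with the always-available @torch.no_grad() rather than _use_grad_for_differentiable, constructing and stepping the optimizer looks the same on old and new PyTorch. A usage sketch under that assumption (hyperparameter values are illustrative, not from the commit):

import torch
from timm.optim.sgdw import SGDW

model = torch.nn.Linear(10, 2)
opt = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.01)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()       # runs under @torch.no_grad() on any supported PyTorch
opt.zero_grad()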