@@ -1,6 +1,13 @@
+from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
+
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
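The `has_recent_pt` flag lets the rest of the module degrade gracefully instead of failing at import time on older PyTorch. A minimal self-contained sketch of the same pattern; the `pick_foreach` helper below is hypothetical and not part of this commit:

import torch

try:
    # Private helper; present in recent PyTorch, absent from older releases.
    from torch.optim.optimizer import _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False


def pick_foreach(params, foreach=None):
    # Hypothetical helper: defer to PyTorch's default selection only when the
    # private helper could be imported; otherwise force the single-tensor path.
    if not has_recent_pt:
        return False
    if foreach is None and not torch.jit.is_scripting():
        _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
    return bool(foreach)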
@@ -62,7 +69,9 @@ def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
 
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    # without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
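The FIXME above is about an asymmetry between the two decorators: `_use_grad_for_differentiable` is applied bare, while `torch.no_grad()` must be called. One possible shape for a compatibility shim, sketched here rather than taken from this commit (`_step_guard` is a hypothetical name):

import torch

try:
    # Recent PyTorch: reuse the stock decorator, which enables grad during step
    # only when the optimizer was constructed with differentiable=True.
    from torch.optim.optimizer import _use_grad_for_differentiable as _step_guard
except ImportError:
    # Older PyTorch: fall back to a plain no-grad guard. torch.no_grad() is a
    # decorator factory, so wrap it once here and call sites can stay bare.
    def _step_guard(func):
        return torch.no_grad()(func)


# The call site would then be version-agnostic:
# @_step_guard
# def step(self, closure=None): ...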
@@ -124,17 +133,19 @@ def sgdw(
 
     See :class:`~torch.optim.SGD` for details.
     """
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
+        if foreach is None:
+            # why must we be explicit about an if statement for torch.jit.is_scripting here?
+            # because JIT can't handle Optionals nor fancy conditionals when scripting
+            if not torch.jit.is_scripting():
+                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
+            else:
+                foreach = False
 
-    if foreach is None:
-        # why must we be explicit about an if statement for torch.jit.is_scripting here?
-        # because JIT can't handle Optionals nor fancy conditionals when scripting
-        if not torch.jit.is_scripting():
-            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
-        else:
-            foreach = False
-
-    if foreach and torch.jit.is_scripting():
-        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+        if foreach and torch.jit.is_scripting():
+            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disable altogether for older pytorch, as the foreach path relies on _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
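Either way, the resolved `foreach` value then selects the single-tensor or multi-tensor implementation. A hypothetical smoke test of the intended end result, with the module path and constructor arguments assumed to mirror `torch.optim.SGD` per the docstring reference above:

import torch

# Assumed import path; constructor args assumed to mirror torch.optim.SGD.
from timm.optim.sgdw import SGDW

# Stepping should work on old and new PyTorch alike; older releases simply
# never take the foreach/multi-tensor path.
model = torch.nn.Linear(4, 2)
opt = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()
opt.zero_grad()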