Skip to content

Commit 1dd2ba1

Browse files
authored
Added amp scale to vision classification template (#342)
* Added amp scale to vision classification template * removed tralining spaces
1 parent c9dd782 commit 1dd2ba1

File tree

1 file changed

+6
-3
lines changed
  • src/templates/template-vision-classification

1 file changed

+6
-3
lines changed

src/templates/template-vision-classification/trainers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ignite.distributed as idist
44
import torch
55
from ignite.engine import DeterministicEngine, Engine, Events
6-
from torch.cuda.amp import autocast
6+
from torch.cuda.amp import autocast, GradScaler
77
from torch.nn import Module
88
from torch.optim import Optimizer
99
from torch.utils.data import DistributedSampler, Sampler
@@ -17,6 +17,8 @@ def setup_trainer(
1717
device: Union[str, torch.device],
1818
train_sampler: Sampler,
1919
) -> Union[Engine, DeterministicEngine]:
20+
scaler = GradScaler(enabled=config.use_amp)
21+
2022
def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
2123
model.train()
2224

@@ -27,9 +29,10 @@ def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
2729
outputs = model(samples)
2830
loss = loss_fn(outputs, targets)
2931

30-
loss.backward()
31-
optimizer.step()
3232
optimizer.zero_grad()
33+
scaler.scale(loss).backward()
34+
scaler.step(optimizer)
35+
scaler.update()
3336

3437
train_loss = loss.item()
3538
engine.state.metrics = {

0 commit comments

Comments
 (0)