|
28 | 28 | from ODFA import ODFA
|
29 | 29 | from utils import save_network
|
30 | 30 | version = torch.__version__
|
31 | | -#fp16
32 | | -try:
33 | | -    from apex.fp16_utils import *
34 | | -    from apex import amp
35 | | -    from apex.optimizers import FusedSGD
36 | | -except ImportError: # will be 3.x series
37 | | -    print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')
38 | | -
39 | 31 | from pytorch_metric_learning import losses, miners #pip install pytorch-metric-learning
|
40 | 32 |
|
41 | 33 | ######################################################################
|
|
56 | 48 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay. More Regularization Smaller Weight.')
|
57 | 49 | parser.add_argument('--total_epoch', default=60, type=int, help='total training epoch')
|
58 | 50 | parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50%% memory' )
|
| 51 | +parser.add_argument('--bf16', action='store_true', help='use bfloat16 instead of float32, which will save about 50%% memory' )
59 | 52 | parser.add_argument('--cosine', action='store_true', help='use cosine lrRate' )
|
60 | 53 | parser.add_argument('--FSGD', action='store_true', help='use fused sgd, which will speed up trainig slightly. apex is needed.' )
|
61 | 54 | parser.add_argument('--wa', action='store_true', help='use weight average' )
|
|
94 | 87 | opt.wa = True #DG will enable swa.
|
95 | 88 |
|
96 | 89 | fp16 = opt.fp16
|
| 90 | +bf16 = opt.bf16
| 91 | +if fp16:
| 92 | +    dtype16 = torch.float16
| 93 | +elif bf16:
| 94 | +    dtype16 = torch.bfloat16
| 95 | +
97 | 96 | data_dir = opt.data_dir
|
98 | 97 | name = opt.name
|
99 | 98 | str_ids = opt.gpu_ids.split(',')
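Taken together, the changes above drop the apex import block and derive one reduced-precision dtype from the new --fp16/--bf16 flags. A minimal sketch of the same selection logic, assuming the two flags are mutually exclusive as in this script, and using None to mark the plain float32 path (an assumption; the patch simply leaves dtype16 undefined in that case):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true', help='train with float16 autocast')
parser.add_argument('--bf16', action='store_true', help='train with bfloat16 autocast')
opt = parser.parse_args()

# Pick the autocast dtype; None means "run in plain float32".
dtype16 = None
if opt.fp16:
    dtype16 = torch.float16
elif opt.bf16:
    dtype16 = torch.bfloat16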
|
@@ -221,7 +220,7 @@ def fliplr(img):
|
221 | 220 | img_flip = img.index_select(3,inv_idx)
|
222 | 221 | return img_flip
|
223 | 222 |
|
224 | | -def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
| 223 | +def train_model(model, criterion, optimizer, scheduler, scaler, num_epochs=25):
225 | 224 | since = time.time()
|
226 | 225 |
|
227 | 226 | #best_model_wts = model.state_dict()
|
@@ -298,6 +297,9 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
298 | 297 | if phase == 'val':
|
299 | 298 | with torch.no_grad():
|
300 | 299 | outputs = model(inputs)
|
| 300 | +elif opt.bf16 or opt.fp16:
| 301 | +    with torch.amp.autocast(device_type='cuda', dtype=dtype16):
| 302 | +        outputs = model(inputs)
301 | 303 | else:
|
302 | 304 | outputs = model(inputs)
|
303 | 305 |
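The val branch keeps torch.no_grad(), while the new elif wraps the training forward pass in native autocast with the chosen dtype. Note the patch wraps only the model call; the canonical torch.amp recipe also computes the loss inside the context, as in this self-contained sketch, where the Linear model, inputs and labels are hypothetical stand-ins rather than the script's network:

import torch
import torch.nn as nn

model = nn.Linear(8, 2).cuda()                     # hypothetical stand-in model
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(4, 8, device='cuda')
labels = torch.randint(0, 2, (4,), device='cuda')
dtype16 = torch.float16                            # or torch.bfloat16

with torch.amp.autocast(device_type='cuda', dtype=dtype16):
    outputs = model(inputs)            # matmuls etc. run in reduced precision
    loss = criterion(outputs, labels)  # cross-entropy runs in float32 under autocast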
|
@@ -371,7 +373,12 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
371 | 373 | inputs1 = inputs1.cuda().detach()
|
372 | 374 | inputs2 = inputs2.cuda().detach()
|
373 | 375 | # use memory in vivo loss (https://arxiv.org/abs/1912.11164)
|
374 | | -outputs1 = model(inputs1)
| 376 | +if bf16 or fp16:
| 377 | +    with torch.amp.autocast(device_type='cuda', dtype=dtype16):
| 378 | +        outputs1 = model(inputs1)
| 379 | +else:
| 380 | +    outputs1 = model(inputs1)
| 381 | +
375 | 382 | if return_feature:
|
376 | 383 | outputs1, _ = outputs1
|
377 | 384 | elif opt.PCB:
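The extra forward pass used by the in-vivo memory loss gets the same guard. The hunk shows only the first of the two passes; a short sketch of the general shape, assuming the second pass is wrapped the same way, that model/inputs1/inputs2 are the script's CUDA tensors, and that dtype16 is chosen as above (loss and backward stay outside the context):

if dtype16 is not None:
    with torch.amp.autocast(device_type='cuda', dtype=dtype16):
        outputs1 = model(inputs1)
        outputs2 = model(inputs2)   # further forward passes go in the same context
else:
    outputs1 = model(inputs1)
    outputs2 = model(inputs2)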
|
@@ -403,12 +410,13 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
403 | 410 | print(loss, warm_up)
|
404 | 411 |
|
405 | 412 | if phase == 'train':
|
406 | | -if fp16: # we use optimier to backward loss
407 | | -    with amp.scale_loss(loss, optimizer) as scaled_loss:
408 | | -        scaled_loss.backward()
| 413 | +if bf16 or fp16: # use the GradScaler to backward the loss
| 414 | +    scaler.scale(loss).backward()
| 415 | +    scaler.step(optimizer) # a safe optimizer.step(): skipped if grads overflow
| 416 | +    scaler.update()
409 | 417 | else:
|
410 | 418 | loss.backward()
|
411 | | -optimizer.step()
| 419 | +    optimizer.step()
412 | 420 | # statistics
|
413 | 421 | if int(version[0])>0 or int(version[2]) > 3: # for the new version like 0.4.0, 0.5.0 and 1.0.0
|
414 | 422 | running_loss += loss.item() * now_batch_size
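The apex scale_loss context is replaced by the GradScaler protocol: scale the loss before backward, step through the scaler (which unscales the gradients and skips the step when they contain inf/NaN), then update the scale factor. A runnable sketch of one such step, again with a hypothetical model and optimizer standing in for the script's own:

import torch
import torch.nn as nn

model = nn.Linear(8, 2).cuda()                          # hypothetical stand-in
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
inputs = torch.randn(4, 8, device='cuda')
labels = torch.randint(0, 2, (4,), device='cuda')

optimizer.zero_grad()
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    loss = nn.functional.cross_entropy(model(inputs), labels)
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.step(optimizer)          # unscales grads; skips the step on inf/NaN
scaler.update()                 # adjusts the loss scale for the next iteration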
|
@@ -619,12 +627,6 @@ def draw_curve(current_epoch):
|
619 | 627 |
|
620 | 628 | criterion = nn.CrossEntropyLoss()
|
621 | 629 |
|
622 | | -if fp16:
623 | | -    #model = network_to_half(model)
624 | | -    #optimizer_ft = FP16_Optimizer(optimizer_ft, static_loss_scale = 128.0)
625 | | -    model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level = "O1")
626 | | -
627 | | -
| 630 | +scaler = torch.cuda.amp.GradScaler()
628 | 631 | model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
|
629 | | -                    num_epochs=opt.total_epoch)
630 | | -
| 632 | +                    scaler, num_epochs=opt.total_epoch)
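One design note: the scaler is constructed unconditionally, and the bf16 path reuses it even though bfloat16 keeps float32's exponent range and normally does not need loss scaling. A hedged alternative (not what this patch does) is to gate the scaler on the fp16 flag, which turns scale()/step()/update() into pass-throughs for bf16 and fp32 runs while keeping the call sites unchanged:

scaler = torch.cuda.amp.GradScaler(enabled=opt.fp16)  # assumption: scale the loss only for float16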