
Commit 54bcdc7

um1 committed: support bf16
1 parent 6f70650 · commit 54bcdc7

File tree: 3 files changed, +27 −30 lines

README.md

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@ A tiny, friendly, strong baseline code for Object-reID (based on [pytorch](https
 
 - **Strong.** It is consistent with the new baseline result in several top-conference works, e.g., [Joint Discriminative and Generative Learning for Person Re-identification(CVPR19)](https://arxiv.org/abs/1904.07223), [Beyond Part Models: Person Retrieval with Refined Part Pooling(ECCV18)](https://arxiv.org/abs/1711.09349), [Camera Style Adaptation for Person Re-identification(CVPR18)](https://arxiv.org/abs/1711.10295). We arrived Rank@1=88.24%, mAP=70.68% only with softmax loss.
 
-- **Small.** With fp16 (supported by Nvidia apex), our baseline could be trained with only 2GB GPU memory.
+- **Small.** With bf16/fp16 (supported by native pytorch), our baseline could be trained with only 2GB GPU memory.
 
 - **Friendly.** You may use the off-the-shelf options to apply many state-of-the-art tricks in one line.
 Besides, if you are new to object re-ID, you may check out our **[Tutorial](https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial)** first (8 min read) :+1: .
@@ -61,12 +61,12 @@ Share to
 Now we have supported:
 
 ### Training
+- bf16 (BFloat16) and fp16 (Float16) to save GPU memory, based on native pytorch (replacing apex).
 - Running the code on Google Colab with Free GPU. Check [Here](https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/colab) (Thanks to @ronghao233)
 - [DG-Market](https://github.com/NVlabs/DG-Net#dg-market) (10x Large Synthetic Dataset from Market **CVPR 2019 Oral**)
 - [Swin Transformer](https://github.com/microsoft/Swin-Transformer) / [EfficientNet](https://github.com/lukemelas/EfficientNet-PyTorch) / [HRNet](https://github.com/HRNet)
 - ResNet/ResNet-ibn/DenseNet
 - Circle Loss, Triplet Loss, Contrastive Loss, Sphere Loss, Lifted Loss, Arcface, Cosface and Instance Loss
-- Float16 to save GPU memory based on [apex](https://github.com/NVIDIA/apex)
 - Part-based Convolutional Baseline(PCB)
 - Random Erasing
 - Linear Warm-up
@@ -193,7 +193,7 @@ Submission DDL is **1 Jan 2025**.
 2018 & 2017 News
 </b></summary>
 
-**What's new:** FP16 has been added. It can be used by simply added `--fp16`. You need to install [apex](https://github.com/NVIDIA/apex) and update your pytorch to 1.0.
+**What's new:** FP16 has been added. It can be used by simply adding `--fp16`. You need to update your pytorch to 2.0.
 
 Float16 could save about 50% GPU memory usage without accuracy drop. **Our baseline could be trained with only 2GB GPU memory.**
 ```bash
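
Side note on the change above: the `--fp16`/`--bf16` switch now rests on PyTorch's native AMP API (`torch.amp.autocast` plus a gradient scaler) instead of apex. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, data, and dimensions are placeholders rather than the repository's actual objects, and a CUDA GPU is assumed.

```python
import torch
import torch.nn as nn

# Placeholder network/optimizer/data; the real script builds these from the parsed options.
model = nn.Linear(512, 751).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()
dtype16 = torch.bfloat16                    # torch.float16 when --fp16 is chosen
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):                          # stand-in for the real training loop
    inputs = torch.randn(8, 512, device='cuda')
    labels = torch.randint(0, 751, (8,), device='cuda')
    optimizer.zero_grad()
    with torch.amp.autocast(device_type='cuda', dtype=dtype16):
        outputs = model(inputs)             # forward pass runs in reduced precision
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()           # scaled backward (the scaling matters mainly for fp16)
    scaler.step(optimizer)                  # skips the update if the gradients overflowed
    scaler.update()
```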

test.py

Lines changed: 0 additions & 5 deletions
@@ -22,11 +22,6 @@
 from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_swinv2, ft_net_efficient, ft_net_NAS, ft_net_convnext, PCB, PCB_test
 from utils import fuse_all_conv_bn
 version = torch.__version__
-#fp16
-try:
-    from apex.fp16_utils import *
-except ImportError: # will be 3.x series
-    print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')
 
 ######################################################################
 # Options
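
The block removed above only existed to pull in apex for the optional low-precision path at test time. With native PyTorch, the same effect comes from wrapping the forward pass in `torch.amp.autocast`; a rough sketch under assumed placeholder objects (`model` and `dataloader` here are hypothetical, not the script's actual variables):

```python
import torch

# Hypothetical stand-ins for the script's network and query/gallery loader.
model = torch.nn.Linear(512, 751).cuda().eval()
dataloader = [(torch.randn(8, 512), torch.zeros(8, dtype=torch.long))]

features = []
with torch.no_grad():
    for img, _ in dataloader:
        # Reduced-precision forward pass; use torch.float16 for an fp16 run.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = model(img.cuda())
        features.append(outputs.float().cpu())   # keep the stored features in fp32
features = torch.cat(features, dim=0)
```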

train.py

Lines changed: 24 additions & 22 deletions
@@ -28,14 +28,6 @@
 from ODFA import ODFA
 from utils import save_network
 version = torch.__version__
-#fp16
-try:
-    from apex.fp16_utils import *
-    from apex import amp
-    from apex.optimizers import FusedSGD
-except ImportError: # will be 3.x series
-    print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')
-
 from pytorch_metric_learning import losses, miners #pip install pytorch-metric-learning
 
 ######################################################################
@@ -56,6 +48,7 @@
 parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay. More Regularization Smaller Weight.')
 parser.add_argument('--total_epoch', default=60, type=int, help='total training epoch')
 parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50%% memory' )
+parser.add_argument('--bf16', action='store_true', help='use bfloat16 instead of float32, which will save about 50%% memory' )
 parser.add_argument('--cosine', action='store_true', help='use cosine lrRate' )
 parser.add_argument('--FSGD', action='store_true', help='use fused sgd, which will speed up trainig slightly. apex is needed.' )
 parser.add_argument('--wa', action='store_true', help='use weight average' )
@@ -94,6 +87,12 @@
 opt.wa = True #DG will enable swa.
 
 fp16 = opt.fp16
+bf16 = opt.bf16
+if fp16:
+    dtype16 = torch.float16
+elif bf16:
+    dtype16 = torch.bfloat16
+
 data_dir = opt.data_dir
 name = opt.name
 str_ids = opt.gpu_ids.split(',')
@@ -221,7 +220,7 @@ def fliplr(img):
     img_flip = img.index_select(3,inv_idx)
     return img_flip
 
-def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
+def train_model(model, criterion, optimizer, scheduler, scaler, num_epochs=25):
     since = time.time()
 
     #best_model_wts = model.state_dict()
@@ -298,6 +297,9 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 if phase == 'val':
                     with torch.no_grad():
                         outputs = model(inputs)
+                elif opt.bf16 or opt.fp16:
+                    with torch.amp.autocast(device_type='cuda', dtype=dtype16):
+                        outputs = model(inputs)
                 else:
                     outputs = model(inputs)
 
@@ -371,7 +373,12 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                     inputs1 = inputs1.cuda().detach()
                     inputs2 = inputs2.cuda().detach()
                     # use memory in vivo loss (https://arxiv.org/abs/1912.11164)
-                    outputs1 = model(inputs1)
+                    if bf16 or fp16:
+                        with torch.amp.autocast(device_type='cuda', dtype=dtype16):
+                            outputs1 = model(inputs1)
+                    else:
+                        outputs1 = model(inputs1)
+
                     if return_feature:
                         outputs1, _ = outputs1
                     elif opt.PCB:
@@ -403,12 +410,13 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 print(loss, warm_up)
 
                 if phase == 'train':
-                    if fp16: # we use optimier to backward loss
-                        with amp.scale_loss(loss, optimizer) as scaled_loss:
-                            scaled_loss.backward()
+                    if bf16 or fp16: # we use the grad scaler to backward the loss
+                        scaler.scale(loss).backward()
+                        scaler.step(optimizer) # a safe optimizer.step() that skips the update on inf/NaN gradients
+                        scaler.update()
                     else:
                         loss.backward()
-                    optimizer.step()
+                        optimizer.step()
                 # statistics
                 if int(version[0])>0 or int(version[2]) > 3: # for the new version like 0.4.0, 0.5.0 and 1.0.0
                     running_loss += loss.item() * now_batch_size
@@ -619,12 +627,6 @@ def draw_curve(current_epoch):
 
 criterion = nn.CrossEntropyLoss()
 
-if fp16:
-    #model = network_to_half(model)
-    #optimizer_ft = FP16_Optimizer(optimizer_ft, static_loss_scale = 128.0)
-    model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level = "O1")
-
-
+scaler = torch.cuda.amp.GradScaler()
 model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
-                    num_epochs=opt.total_epoch)
-
+                    scaler, num_epochs=opt.total_epoch)
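
One detail worth flagging about the new `GradScaler`: loss scaling exists to keep fp16 gradients from underflowing, while bf16 shares fp32's exponent range and normally does not need it. The scaler is harmless in the bf16 case, but it can also be disabled explicitly without touching the training-loop branch; a small sketch of that variant (an option, not what this commit does):

```python
import torch

# Sketch only: enable loss scaling for fp16, make the scaler a pass-through otherwise.
use_fp16 = True    # would come from opt.fp16
use_bf16 = False   # would come from opt.bf16
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16 and not use_bf16)

# With enabled=False, scaler.scale(loss) returns the loss unchanged and
# scaler.step(optimizer) simply calls optimizer.step(), so the
# `if bf16 or fp16:` branch in train.py keeps working as written.
```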
