Commit 1a88bab

zzd@um4 committed: fix usam; add swa
1 parent ec7cfb6, commit 1a88bab

File tree: 3 files changed (+41, -9 lines changed)

model.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ def weights_init_classifier(m):
 
 class USAM(nn.Module):
     # Joint Representation Learning and Keypoint Detection for Cross-view Geo-localization. TIP 2022
-    def __init__(self, kernel_size=3, padding=1, polish=True):
+    def __init__(self, kernel_size=3, padding=1, polish=False):
         super(USAM, self).__init__()
 
         kernel = torch.ones((kernel_size, kernel_size))
@@ -123,8 +123,8 @@ def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=Fals
         self.model = model_ft
         self.circle = circle
         self.classifier = ClassBlock(2048, class_num, droprate, linear=linear_num, return_f = circle)
+        self.usam = usam
         if usam:
-            self.usam = usam
             self.usam_1 = USAM()
             self.usam_2 = USAM()
 
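Two things change here. The USAM constructor default flips to polish=False, and, covering the "fix usam" half of the commit message, self.usam = usam moves out of the if usam: branch so the attribute exists even when USAM is disabled. A minimal sketch of the failure mode the move avoids; the forward() below is a hypothetical stand-in, not this repository's real method:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, usam=False):
        super().__init__()
        if usam:
            self.usam = usam  # old layout: attribute set only when usam=True
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        if self.usam:    # raises AttributeError when usam=False, because
            x = x * 2.0  # nn.Module has no fallback for unset plain attributes
        return self.fc(x)

try:
    Net(usam=False)(torch.randn(1, 8))
except AttributeError as e:
    print(e)  # 'Net' object has no attribute 'usam'

Hoisting the assignment above the branch makes self.usam an always-present boolean flag.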
test.py

Lines changed: 9 additions & 2 deletions

@@ -17,6 +17,7 @@
 import scipy.io
 import yaml
 import math
+from torch.optim import swa_utils
 from tqdm import tqdm
 from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_swinv2, ft_net_efficient, ft_net_NAS, ft_net_convnext, PCB, PCB_test
 from utils import fuse_all_conv_bn
@@ -167,9 +168,15 @@ def load_network(network):
         print("Compiling model...")
         # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
         torch.set_float32_matmul_precision('high')
-        network = torch.compile(network, mode="default", dynamic=True) # pytorch 2.0
+        network.cuda()
+        network = torch.compile(network, mode="reduce-overhead", dynamic=True) # pytorch 2.0
+    if 'average' in opt.which_epoch: # load averaged model.
+        network = swa_utils.AveragedModel(network)
     network.load_state_dict(torch.load(save_path))
-
+    if 'average' in opt.which_epoch:
+        print("We average %d snapshots" % network.n_averaged)
+        #swa_utils.update_bn(dataloaders['query'], network, device='cuda:0')
+        network = network.module
     return network
 
 
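The wrap-then-load order matters: a checkpoint saved from an AveragedModel prefixes every key of the wrapped network with module. and carries an extra n_averaged buffer, so test.py rebuilds the same wrapper before load_state_dict and unwraps with .module afterwards. A small sketch of that round trip, using a toy linear model for illustration:

import torch.nn as nn
from torch.optim import swa_utils

swa_net = swa_utils.AveragedModel(nn.Linear(4, 2))
sd = swa_net.state_dict()
print(sorted(sd))  # ['module.bias', 'module.weight', 'n_averaged']

# Loading needs the same wrapper around a freshly built network...
fresh = swa_utils.AveragedModel(nn.Linear(4, 2))
fresh.load_state_dict(sd)
print(int(fresh.n_averaged))  # how many snapshots went into the average

# ...and .module recovers the plain network for inference.
network = fresh.module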
train.py

Lines changed: 30 additions & 5 deletions

@@ -16,6 +16,7 @@
 import time
 import os
 import collections
+from torch.optim import swa_utils
 from tqdm import tqdm
 from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_swinv2, ft_net_convnext, ft_net_efficient, ft_net_NAS, PCB
 from random_erasing import RandomErasing
@@ -57,6 +58,7 @@
 parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50%% memory' )
 parser.add_argument('--cosine', action='store_true', help='use cosine lrRate' )
 parser.add_argument('--FSGD', action='store_true', help='use fused sgd, which will speed up training slightly. apex is needed.' )
+parser.add_argument('--wa', action='store_true', help='use weight average' )
 # backbone
 parser.add_argument('--linear_num', default=512, type=int, help='feature dimension: 512 or default or 0 (linear=False)')
 parser.add_argument('--stride', default=2, type=int, help='stride')
@@ -88,6 +90,9 @@
 
 opt = parser.parse_args()
 
+if opt.DG:
+    opt.wa = True # DG will enable swa.
+
 fp16 = opt.fp16
 data_dir = opt.data_dir
 name = opt.name
@@ -221,6 +226,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
 
     #best_model_wts = model.state_dict()
     #best_acc = 0.0
+    wa_flag = opt.wa
     warm_up = 0.1 # We start from the 0.1*lrRate
     warm_iteration = round(dataset_sizes['train']/opt.batchsize)*opt.warm_epoch # first 5 epoch
     embedding_size = model.classifier.linear_num
@@ -244,6 +250,12 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     for epoch in range(num_epochs):
         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
         # print('-' * 10)
+
+        if opt.wa and wa_flag and epoch >= num_epochs*0.1:
+            wa_flag = False
+            swa_model = swa_utils.AveragedModel(model)
+            swa_model.avg_fn = swa_utils.get_ema_avg_fn(decay=0.996)
+            print('start weight avg')
 
         # Each epoch has a training and validation phase
         for phase in ['train', 'val']:
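Swapping in get_ema_avg_fn(decay=0.996) turns AveragedModel from an equal-weight running mean into an exponential moving average: each update computes w_avg = 0.996*w_avg + 0.004*w. A toy sketch of the difference (one-weight model, values picked only for illustration); note it passes avg_fn through the constructor, which is the documented route, whereas the diff assigns the attribute after construction:

import torch
import torch.nn as nn
from torch.optim import swa_utils

model = nn.Linear(1, 1, bias=False)
ema = swa_utils.AveragedModel(model, avg_fn=swa_utils.get_ema_avg_fn(decay=0.996))

with torch.no_grad():
    model.weight.fill_(1.0)
ema.update_parameters(model)  # first call just copies the weights
with torch.no_grad():
    model.weight.fill_(0.0)
ema.update_parameters(model)  # EMA: 0.996*1.0 + 0.004*0.0
print(ema.module.weight.item())  # ~0.996; the default running mean would give 0.5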
@@ -286,8 +298,6 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 else:
                     outputs = model(inputs)
 
-
-
                 if opt.adv>0 and iter%opt.aiter==0:
                     inputs_adv = ODFA(model, inputs)
                     outputs_adv = model(inputs_adv)
@@ -365,17 +375,22 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                         for i in range(num_part):
                             part[i] = outputs1[i]
                         outputs1 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
-                    outputs2 = model(inputs2)
+
+                    swa_model.eval()
+                    with torch.no_grad():
+                        outputs2 = swa_model(inputs2) # stop gradient like dino
+                    outputs2 = outputs2.detach()
+
                     if return_feature:
                         outputs2, _ = outputs2
                     elif opt.PCB:
                         for i in range(num_part):
                             part[i] = outputs2[i]
                         outputs2 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
 
-                    mean_pred = sm(outputs1 + outputs2)
+                    # supervised via teacher like dino; previously used sm(outputs1 + outputs2)
                     kl_loss = nn.KLDivLoss(reduction='batchmean')
-                    reg = (kl_loss(log_sm(outputs2), mean_pred) + kl_loss(log_sm(outputs1), mean_pred))/2
+                    reg = (kl_loss(log_sm(outputs2), sm(outputs1)) + kl_loss(log_sm(outputs1), sm(outputs2)))/2
                     loss += 0.01*reg
                     del inputs1, inputs2
                     #print(0.01*reg)
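The regularizer changes from pulling both views toward their mean prediction to a symmetric KL against the EMA teacher: outputs2 comes from the frozen swa_model under no_grad, so gradients reach the student only through outputs1, which is the stop-gradient pattern the DINO comment refers to. A self-contained sketch of the term, with sm and log_sm standing in for the script's softmax helpers and 751 matching the default class_num:

import torch
import torch.nn as nn

sm, log_sm = nn.Softmax(dim=1), nn.LogSoftmax(dim=1)
kl_loss = nn.KLDivLoss(reduction='batchmean')

outputs1 = torch.randn(8, 751, requires_grad=True)  # student logits
with torch.no_grad():
    outputs2 = torch.randn(8, 751)                   # detached EMA-teacher logits

# KLDivLoss expects log-probabilities as input and probabilities as target;
# averaging both directions symmetrizes the divergence.
reg = (kl_loss(log_sm(outputs2), sm(outputs1)) +
       kl_loss(log_sm(outputs1), sm(outputs2))) / 2
loss = 0.01 * reg  # same 0.01 weight as the diff; grads flow only via outputs1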
@@ -419,6 +434,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
             pbar.set_postfix(ordered_dict=ordered_dict)
             pbar.close()
 
+            if phase == 'train' and opt.wa and epoch >= num_epochs*0.1:
+                swa_model.update_parameters(model)
+                swa_utils.update_bn(dataloaders['train'], swa_model, device='cuda:0')
+
             y_loss[phase].append(epoch_loss)
             y_err[phase].append(1.0-epoch_acc)
             # deep copy the model
# deep copy the model
@@ -449,6 +468,11 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
449468
else:
450469
save_network(model, opt.name, 'last')
451470

471+
if opt.wa:
472+
save_network( swa_model, opt.name, 'average')
473+
swa_utils.update_bn(dataloaders['train'], swa_model, device='cuda:0')
474+
save_network( swa_model, opt.name, 'average_bn')
475+
452476
return model
453477

454478

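This saves two averaged checkpoints: 'average' straight after training and 'average_bn' after a final BN recalibration pass. Both come from an AveragedModel wrapper, which is why the 'average' in opt.which_epoch branch added to test.py above rebuilds the same wrapper before loading.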
@@ -511,6 +535,7 @@ def draw_curve(current_epoch):
 
 if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7 and one gpu
     torch.set_float32_matmul_precision('high')
+    torch._dynamo.config.automatic_dynamic_shapes = True
     print("Compiling model... The first epoch may be slow, which is expected!")
     # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
     model = torch.compile(model, mode="reduce-overhead", dynamic=True) # pytorch 2.0
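For context: mode="reduce-overhead" trades extra compile time and memory for CUDA-graph capture that cuts per-iteration launch overhead, and dynamic=True together with the automatic_dynamic_shapes flag asks TorchDynamo to compile shape-generic kernels instead of recompiling whenever an input shape changes (for example, the smaller final batch of an epoch).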
