@@ -16,6 +16,7 @@
 import time
 import os
 import collections
+from torch.optim import swa_utils
 from tqdm import tqdm
 from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_swinv2, ft_net_convnext, ft_net_efficient, ft_net_NAS, PCB
 from random_erasing import RandomErasing
@@ -57,6 +58,7 @@
 parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50%% memory')
 parser.add_argument('--cosine', action='store_true', help='use cosine lrRate')
 parser.add_argument('--FSGD', action='store_true', help='use fused sgd, which will speed up training slightly. apex is needed.')
+parser.add_argument('--wa', action='store_true', help='use weight averaging')
 # backbone
 parser.add_argument('--linear_num', default=512, type=int, help='feature dimension: 512 or default or 0 (linear=False)')
 parser.add_argument('--stride', default=2, type=int, help='stride')
@@ -88,6 +90,9 @@

 opt = parser.parse_args()

+if opt.DG:
+    opt.wa = True  # DG will enable SWA.
+
 fp16 = opt.fp16
 data_dir = opt.data_dir
 name = opt.name
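
For reference, a standalone sketch of how the new flag and the DG coupling behave (argparse only; the names mirror the diff, and the parse_args list simulates a command line):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--wa', action='store_true', help='use weight averaging')
parser.add_argument('--DG', action='store_true', help='use DG loss')

opt = parser.parse_args(['--DG'])  # simulates: python train.py --DG
if opt.DG:
    opt.wa = True  # DG implies weight averaging, as in the diff
print(opt.wa)      # True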
@@ -221,6 +226,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):

     #best_model_wts = model.state_dict()
     #best_acc = 0.0
+    wa_flag = opt.wa
     warm_up = 0.1  # We start from the 0.1*lrRate
     warm_iteration = round(dataset_sizes['train']/opt.batchsize)*opt.warm_epoch  # first 5 epoch
     embedding_size = model.classifier.linear_num
@@ -244,6 +250,12 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     for epoch in range(num_epochs):
         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
         # print('-' * 10)
+
+        if opt.wa and wa_flag and epoch >= num_epochs * 0.1:
+            wa_flag = False
+            swa_model = swa_utils.AveragedModel(model)
+            swa_model.avg_fn = swa_utils.get_ema_avg_fn(decay=0.996)
+            print('start weight avg')

         # Each epoch has a training and validation phase
         for phase in ['train', 'val']:
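
A minimal, self-contained sketch of the EMA averaging set up above (torch.optim.swa_utils is real; the toy model and loop are placeholders). Passing avg_fn through the constructor is the documented route; the diff assigns it after construction, which works because AveragedModel stores it as a plain attribute:

import torch.nn as nn
from torch.optim import swa_utils

model = nn.Linear(4, 2)  # stand-in for the re-ID backbone

# EMA rule: avg = decay * avg + (1 - decay) * current
swa_model = swa_utils.AveragedModel(
    model, avg_fn=swa_utils.get_ema_avg_fn(decay=0.996))

for step in range(10):                  # training-loop stand-in
    swa_model.update_parameters(model)  # fold current weights into the EMA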
@@ -286,8 +298,6 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 else:
                     outputs = model(inputs)

-
-
                 if opt.adv > 0 and iter % opt.aiter == 0:
                     inputs_adv = ODFA(model, inputs)
                     outputs_adv = model(inputs_adv)
@@ -365,17 +375,22 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                         for i in range(num_part):
                             part[i] = outputs1[i]
                         outputs1 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
-                    outputs2 = model(inputs2)
+
+                    swa_model.eval()
+                    with torch.no_grad():
+                        outputs2 = swa_model(inputs2)  # stop gradient, like DINO
+                    outputs2 = outputs2.detach()
+
                     if return_feature:
                         outputs2, _ = outputs2
                     elif opt.PCB:
                         for i in range(num_part):
                             part[i] = outputs2[i]
                         outputs2 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]

-                    mean_pred = sm(outputs1 + outputs2)
+                    # supervised via the teacher, like DINO; previously used sm(outputs1 + outputs2)
                     kl_loss = nn.KLDivLoss(reduction='batchmean')
-                    reg = (kl_loss(log_sm(outputs2), mean_pred) + kl_loss(log_sm(outputs1), mean_pred))/2
+                    reg = (kl_loss(log_sm(outputs2), sm(outputs1)) + kl_loss(log_sm(outputs1), sm(outputs2)))/2
                     loss += 0.01 * reg
                     del inputs1, inputs2
                     # print(0.01*reg)
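
In isolation, the reworked regularizer is a symmetric KL between student and teacher logits rather than a KL against their mean; a toy reproduction (batch and class sizes are illustrative, 751 being Market-1501's class count):

import torch
import torch.nn as nn

sm, log_sm = nn.Softmax(dim=1), nn.LogSoftmax(dim=1)
kl_loss = nn.KLDivLoss(reduction='batchmean')

outputs1 = torch.randn(8, 751, requires_grad=True)  # student logits
outputs2 = torch.randn(8, 751)                      # teacher logits, detached in the diff

# KLDivLoss expects log-probabilities as input and probabilities as target,
# so each direction pairs log_sm of one head with sm of the other.
reg = (kl_loss(log_sm(outputs2), sm(outputs1)) +
       kl_loss(log_sm(outputs1), sm(outputs2))) / 2
reg.backward()  # gradients reach only outputs1; outputs2 carries no graph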
@@ -419,6 +434,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 pbar.set_postfix(ordered_dict=ordered_dict)
             pbar.close()

+            if phase == 'train' and opt.wa and epoch >= num_epochs * 0.1:
+                swa_model.update_parameters(model)
+                swa_utils.update_bn(dataloaders['train'], swa_model, device='cuda:0')
+
             y_loss[phase].append(epoch_loss)
             y_err[phase].append(1.0 - epoch_acc)
             # deep copy the model
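
For context, swa_utils.update_bn does one full forward pass over the loader to re-estimate BatchNorm running statistics, which the averaged weights otherwise lack; a toy sketch (model and data are placeholders):

import torch
import torch.nn as nn
from torch.optim import swa_utils
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
swa_model = swa_utils.AveragedModel(model)

loader = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=16)
swa_utils.update_bn(loader, swa_model)  # resets, then re-estimates BN stats in one pass

Calling it every epoch, as the hunk above does, costs an extra pass over the training set per epoch; the cheaper, more common pattern is a single call before evaluation.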
@@ -449,6 +468,11 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     else:
         save_network(model, opt.name, 'last')

+    if opt.wa:
+        save_network(swa_model, opt.name, 'average')
+        swa_utils.update_bn(dataloaders['train'], swa_model, device='cuda:0')
+        save_network(swa_model, opt.name, 'average_bn')
+
     return model

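
AveragedModel wraps the network it averages, so its state_dict carries an n_averaged buffer and 'module.'-prefixed keys; if save_network serializes swa_model as-is (its internals are not shown in this diff), one way to recover weights loadable by the plain network:

# Unwrap the averaged weights directly:
plain_state = swa_model.module.state_dict()

# Or strip the wrapper keys from an already-saved state_dict:
wrapped = swa_model.state_dict()
plain_state = {k.replace('module.', '', 1): v
               for k, v in wrapped.items() if k != 'n_averaged'}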
@@ -511,6 +535,7 @@ def draw_curve(current_epoch):

 if torch.cuda.get_device_capability()[0] > 6 and len(opt.gpu_ids) == 1 and int(version[0]) > 1:  # should be >=7 and one gpu
     torch.set_float32_matmul_precision('high')
+    torch._dynamo.config.automatic_dynamic_shapes = True
     print("Compiling model... The first epoch may be slow, which is expected!")
     # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
     model = torch.compile(model, mode="reduce-overhead", dynamic=True)  # pytorch 2.0
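
torch._dynamo is a private namespace whose options shift between releases; a defensive variant of the same configuration (the try/except guard is an assumption for older builds, not part of the diff):

import torch

try:
    import torch._dynamo
    # allow the compiler to re-specialize when input shapes change
    torch._dynamo.config.automatic_dynamic_shapes = True
except ImportError:
    pass  # pre-2.0 builds ship no dynamo

compiled = torch.compile(torch.nn.Linear(4, 4), mode="reduce-overhead", dynamic=True)
out = compiled(torch.randn(2, 4))  # first call triggers compilation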