Skip to content

Commit a5d88a4

Browse files
author
um1
committed
make a probability
1 parent c7cc09b commit a5d88a4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
377377
elif opt.PCB:
378378
for i in range(num_part):
379379
part[i] = outputs1[i]
380-
outputs1 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
380+
outputs1 = (sm(part[0]) + sm(part[1]) +sm(part[2]) + sm(part[3]) +sm(part[4]) +sm(part[5]))/6
381381

382382
swa_model.eval()
383383
with torch.no_grad():
@@ -388,7 +388,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
388388
elif opt.PCB:
389389
for i in range(num_part):
390390
part[i] = outputs2[i]
391-
outputs2 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
391+
outputs2 = (sm(part[0]) + sm(part[1]) +sm(part[2]) + sm(part[3]) +sm(part[4]) +sm(part[5]))/6
392392

393393
#supervised via teacher like dino. previous use sm(outputs1 + outputs2)
394394
kl_loss = nn.KLDivLoss(reduction='batchmean')

0 commit comments

Comments
 (0)