Skip to content

Commit c6266f2

Browse files
committed
Add missing file
1 parent 3abaed0 commit c6266f2

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

python/fcdd/util/tb.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
from torch.utils.tensorboard import SummaryWriter
3+
from torchvision.utils import make_grid
4+
from matplotlib.pyplot import get_cmap
5+
import numpy as np
6+
7+
class TBLogger:
8+
def __init__(self, log_dir):
9+
self.writer = SummaryWriter(log_dir=log_dir)
10+
11+
def add_scalars(self, loss, lr, epoch):
12+
self.writer.add_scalar('Loss/train', loss, epoch)
13+
self.writer.add_scalar('LR', lr, epoch)
14+
self.writer.flush()
15+
16+
def add_scalars(self, loss, loss_normal, loss_anomal, lr, epoch):
17+
self.writer.add_scalars('Loss/train', {
18+
'total': loss,
19+
'normal': loss_normal,
20+
'anomalous': loss_anomal
21+
}, epoch)
22+
23+
self.writer.add_scalar('LR', lr, epoch)
24+
self.writer.flush()
25+
26+
def add_images(self, inputs: torch.Tensor, gt_maps: (None, torch.Tensor), outputs: torch.Tensor,
27+
normal: bool, epoch: int):
28+
main_tag = 'normal' if normal else 'anomalous'
29+
30+
cmap = get_cmap('jet')
31+
outputs_new = []
32+
outputs = outputs.clone()
33+
34+
def norm_ip(img, min, max):
35+
img.clamp_(min=min, max=max)
36+
img.add_(-min).div_(max - min + 1e-5)
37+
38+
norm_ip(outputs, float(outputs.min()), float(outputs.max()))
39+
for img in outputs.squeeze(dim=1):
40+
outputs_new.append(cmap(img.detach().cpu().numpy())[:, :, :3])
41+
outputs = torch.tensor(outputs_new).permute(0, 3, 1, 2)
42+
43+
for tag, imgs in zip(['inputs', 'gt_maps', 'outputs'], [inputs, gt_maps, outputs]):
44+
if imgs is not None:
45+
batch_size = imgs.size(0)
46+
nrow = int(np.sqrt(batch_size))
47+
grid = make_grid(imgs, nrow=nrow)
48+
self.writer.add_image(main_tag + '/' + tag, grid, epoch)
49+
self.writer.flush()
50+
51+
def add_network(self, model: torch.nn.Module, input_to_model):
52+
self.writer.add_graph(model, input_to_model)
53+
self.writer.flush()
54+
55+
def add_weight_histograms(self, model, epoch):
56+
for name, m in model.named_modules():
57+
if isinstance(m, torch.nn.Conv2d):
58+
self.writer.add_histogram(name + '.weight', m.weight, epoch)
59+
if m.bias is not None:
60+
self.writer.add_histogram(name + '.bias', m.bias, epoch)
61+
self.writer.flush()
62+
63+
def close(self):
64+
self.writer.close()

0 commit comments

Comments
 (0)