|
| 1 | +import torch |
| 2 | +from torch.utils.tensorboard import SummaryWriter |
| 3 | +from torchvision.utils import make_grid |
| 4 | +from matplotlib.pyplot import get_cmap |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +class TBLogger: |
| 8 | + def __init__(self, log_dir): |
| 9 | + self.writer = SummaryWriter(log_dir=log_dir) |
| 10 | + |
| 11 | + def add_scalars(self, loss, lr, epoch): |
| 12 | + self.writer.add_scalar('Loss/train', loss, epoch) |
| 13 | + self.writer.add_scalar('LR', lr, epoch) |
| 14 | + self.writer.flush() |
| 15 | + |
| 16 | + def add_scalars(self, loss, loss_normal, loss_anomal, lr, epoch): |
| 17 | + self.writer.add_scalars('Loss/train', { |
| 18 | + 'total': loss, |
| 19 | + 'normal': loss_normal, |
| 20 | + 'anomalous': loss_anomal |
| 21 | + }, epoch) |
| 22 | + |
| 23 | + self.writer.add_scalar('LR', lr, epoch) |
| 24 | + self.writer.flush() |
| 25 | + |
| 26 | + def add_images(self, inputs: torch.Tensor, gt_maps: (None, torch.Tensor), outputs: torch.Tensor, |
| 27 | + normal: bool, epoch: int): |
| 28 | + main_tag = 'normal' if normal else 'anomalous' |
| 29 | + |
| 30 | + cmap = get_cmap('jet') |
| 31 | + outputs_new = [] |
| 32 | + outputs = outputs.clone() |
| 33 | + |
| 34 | + def norm_ip(img, min, max): |
| 35 | + img.clamp_(min=min, max=max) |
| 36 | + img.add_(-min).div_(max - min + 1e-5) |
| 37 | + |
| 38 | + norm_ip(outputs, float(outputs.min()), float(outputs.max())) |
| 39 | + for img in outputs.squeeze(dim=1): |
| 40 | + outputs_new.append(cmap(img.detach().cpu().numpy())[:, :, :3]) |
| 41 | + outputs = torch.tensor(outputs_new).permute(0, 3, 1, 2) |
| 42 | + |
| 43 | + for tag, imgs in zip(['inputs', 'gt_maps', 'outputs'], [inputs, gt_maps, outputs]): |
| 44 | + if imgs is not None: |
| 45 | + batch_size = imgs.size(0) |
| 46 | + nrow = int(np.sqrt(batch_size)) |
| 47 | + grid = make_grid(imgs, nrow=nrow) |
| 48 | + self.writer.add_image(main_tag + '/' + tag, grid, epoch) |
| 49 | + self.writer.flush() |
| 50 | + |
| 51 | + def add_network(self, model: torch.nn.Module, input_to_model): |
| 52 | + self.writer.add_graph(model, input_to_model) |
| 53 | + self.writer.flush() |
| 54 | + |
| 55 | + def add_weight_histograms(self, model, epoch): |
| 56 | + for name, m in model.named_modules(): |
| 57 | + if isinstance(m, torch.nn.Conv2d): |
| 58 | + self.writer.add_histogram(name + '.weight', m.weight, epoch) |
| 59 | + if m.bias is not None: |
| 60 | + self.writer.add_histogram(name + '.bias', m.bias, epoch) |
| 61 | + self.writer.flush() |
| 62 | + |
| 63 | + def close(self): |
| 64 | + self.writer.close() |
0 commit comments