
How to use grad-cam with TransUNet #559

@alqurri

Description


Hi,

I'm trying to use Grad-CAM with TransUNet for segmentation (model code: https://github.com/Beckschen/TransUNet):

from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50

model = net  # the trained TransUNet model (was: resnet50(pretrained=True))
target_layers = [net.segmentation_head]  # GradCAM expects a list of layers (was: [net.layer4[-1]])
input_tensor = sampled_batch2["image"]  # create an input tensor image for your model
# Note: input_tensor can be a batch tensor with several images!

# We have to specify the target we want to generate the CAM for.
targets = [ClassifierOutputTarget(0)]

# Construct the CAM object once, and then re-use it on many images.
with GradCAM(model=model, target_layers=target_layers) as cam:
  # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
  grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
  # In this example grayscale_cam has only one image in the batch:
  grayscale_cam = grayscale_cam[0, :]
  # rgb_img is assumed to be defined earlier as an HxWx3 float image in [0, 1]
  visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
  # You can also get the model outputs without having to redo inference
  model_outputs = cam.outputs
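
As an aside: ClassifierOutputTarget(0) indexes a classifier's score vector, so it does not reduce a segmentation model's (C, H, W) logit map to the scalar the backward pass needs. The pytorch-grad-cam semantic segmentation tutorial instead sums one class's logits over a region mask. Below is a minimal sketch of that pattern; the class index and the all-ones mask are placeholder assumptions:

import numpy as np
import torch

class SemanticSegmentationTarget:
    # Reduce a (C, H, W) logit map to a scalar: sum the logits of one
    # class channel over a binary region mask.
    def __init__(self, category, mask):
        self.category = category            # class channel to explain
        self.mask = torch.from_numpy(mask)  # HxW float mask, 1 = region of interest

    def __call__(self, model_output):
        return (model_output[self.category, :, :] * self.mask).sum()

# Placeholder example: explain class 0 over the whole image.
mask = np.ones((224, 224), dtype=np.float32)
targets = [SemanticSegmentationTarget(0, mask)]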

But when I run the original block above, I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[65], line 17
     14 # Construct the CAM object once, and then re-use it on many images.
     15 with GradCAM(model=model, target_layers=target_layers) as cam:
     16   # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
---> 17   grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
     18   # In this example grayscale_cam has only one image in the batch:
     19   grayscale_cam = grayscale_cam[0, :]

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/pytorch_grad_cam/base_cam.py:186, in BaseCAM.__call__(self, input_tensor, targets, aug_smooth, eigen_smooth)
    183 if aug_smooth is True:
    184     return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
--> 186 return self.forward(input_tensor, targets, eigen_smooth)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/pytorch_grad_cam/base_cam.py:90, in BaseCAM.forward(self, input_tensor, targets, eigen_smooth)
     87 if self.compute_input_gradient:
     88     input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
---> 90 self.outputs = outputs = self.activations_and_grads(input_tensor)
     92 if targets is None:
     93     target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/pytorch_grad_cam/activations_and_gradients.py:42, in ActivationsAndGradients.__call__(self, x)
     40 self.gradients = []
     41 self.activations = []
---> 42 return self.model(x)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

Cell In[6], line 395, in VisionTransformer.forward(self, x)
    393 if x.size()[1] == 1:
    394     x = x.repeat(1,3,1,1)
--> 395 x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
    396 x = self.decoder(x, features)
    397 logits = self.segmentation_head(x)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
   1558     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1559     args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
   1562 if _global_forward_hooks or self._forward_hooks:
   1563     for hook_id, hook in (
   1564         *_global_forward_hooks.items(),
   1565         *self._forward_hooks.items(),
   1566     ):
   1567         # mark that always called hook is run

Cell In[6], line 258, in Transformer.forward(self, input_ids)
    257 def forward(self, input_ids):
--> 258     embedding_output, features = self.embeddings(input_ids)
    260     encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
    262     return encoded, attn_weights, features

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
   1558     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1559     args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
   1562 if _global_forward_hooks or self._forward_hooks:
   1563     for hook_id, hook in (
   1564         *_global_forward_hooks.items(),
   1565         *self._forward_hooks.items(),
   1566     ):
   1567         # mark that always called hook is run

Cell In[6], line 156, in Embeddings.forward(self, x)
    154 def forward(self, x):
    155     if self.hybrid:
--> 156         x, features = self.hybrid_model(x)
    157     else:
    158         features = None

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1574, in Module._call_impl(self, *args, **kwargs)
   1572     hook_result = hook(self, args, kwargs, result)
   1573 else:
-> 1574     hook_result = hook(self, args, result)
   1576 if hook_result is not None:
   1577     result = hook_result

File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torchsummary/torchsummary.py:23, in summary.<locals>.register_hook.<locals>.hook(module, input, output)
     20 summary[m_key]["input_shape"][0] = batch_size
     21 if isinstance(output, (list, tuple)):
     22     summary[m_key]["output_shape"] = [
---> 23         [-1] + list(o.size())[1:] for o in output
     24     ]
     25 else:
     26     summary[m_key]["output_shape"] = list(output.size())

AttributeError: 'list' object has no attribute 'size'
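
For what it's worth, the last frame of the traceback points at torchsummary, not pytorch_grad_cam: a forward hook registered by torchsummary.summary() assumes every element of a module's tuple output is a tensor, but TransUNet's hybrid backbone returns (x, features) where features is a Python list, hence the AttributeError. torchsummary does not remove its hooks if the summary forward pass itself crashes, so if summary() was called on this model earlier and failed, the hooks stay registered and break every later forward. Below is a sketch of clearing such leftover hooks; it touches the private _forward_hooks dict, so re-instantiating the model and reloading its weights is the cleaner fix:

import torch.nn as nn

def clear_stale_forward_hooks(model: nn.Module) -> None:
    # Drop any forward hooks left behind by an earlier torchsummary call.
    # Run this BEFORE constructing GradCAM, which installs hooks of its own.
    for name, module in model.named_modules():
        if module._forward_hooks:
            print(f"clearing {len(module._forward_hooks)} hook(s) on {name}")
            module._forward_hooks.clear()
        module._forward_pre_hooks.clear()

clear_stale_forward_hooks(net)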
