-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Open
Description
Hi;
I'm trying to use grad cam with TransUNet for segmentation. Code at https://github.com/Beckschen/TransUNet
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
model = net#resnet50(pretrained=True)
target_layers = net.segmentation_head #[net.layer4[-1]]
input_tensor = sampled_batch2["image"]# Create an input tensor image for your model..
# Note: input_tensor can be a batch tensor with several images!
# We have to specify the target we want to generate the CAM for.
targets = [ClassifierOutputTarget(0)]
# Construct the CAM object once, and then re-use it on many images.
with GradCAM(model=model, target_layers=target_layers) as cam:
# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
# In this example grayscale_cam has only one image in the batch:
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
# You can also get the model outputs without having to redo inference
model_outputs = cam.outputs
but I'm getting this error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[65], line 17
14 # Construct the CAM object once, and then re-use it on many images.
15 with GradCAM(model=model, target_layers=target_layers) as cam:
16 # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
---> 17 grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
18 # In this example grayscale_cam has only one image in the batch:
19 grayscale_cam = grayscale_cam[0, :]
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/pytorch_grad_cam/base_cam.py:186, in BaseCAM.__call__(self, input_tensor, targets, aug_smooth, eigen_smooth)
183 if aug_smooth is True:
184 return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
--> 186 return self.forward(input_tensor, targets, eigen_smooth)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/pytorch_grad_cam/base_cam.py:90, in BaseCAM.forward(self, input_tensor, targets, eigen_smooth)
87 if self.compute_input_gradient:
88 input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
---> 90 self.outputs = outputs = self.activations_and_grads(input_tensor)
92 if targets is None:
93 target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/pytorch_grad_cam/activations_and_gradients.py:42, in ActivationsAndGradients.__call__(self, x)
40 self.gradients = []
41 self.activations = []
---> 42 return self.model(x)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
Cell In[6], line 395, in VisionTransformer.forward(self, x)
393 if x.size()[1] == 1:
394 x = x.repeat(1,3,1,1)
--> 395 x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
396 x = self.decoder(x, features)
397 logits = self.segmentation_head(x)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
1558 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1559 args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
1562 if _global_forward_hooks or self._forward_hooks:
1563 for hook_id, hook in (
1564 *_global_forward_hooks.items(),
1565 *self._forward_hooks.items(),
1566 ):
1567 # mark that always called hook is run
Cell In[6], line 258, in Transformer.forward(self, input_ids)
257 def forward(self, input_ids):
--> 258 embedding_output, features = self.embeddings(input_ids)
260 encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
262 return encoded, attn_weights, features
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1561, in Module._call_impl(self, *args, **kwargs)
1558 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1559 args = bw_hook.setup_input_hook(args)
-> 1561 result = forward_call(*args, **kwargs)
1562 if _global_forward_hooks or self._forward_hooks:
1563 for hook_id, hook in (
1564 *_global_forward_hooks.items(),
1565 *self._forward_hooks.items(),
1566 ):
1567 # mark that always called hook is run
Cell In[6], line 156, in Embeddings.forward(self, x)
154 def forward(self, x):
155 if self.hybrid:
--> 156 x, features = self.hybrid_model(x)
157 else:
158 features = None
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1574, in Module._call_impl(self, *args, **kwargs)
1572 hook_result = hook(self, args, kwargs, result)
1573 else:
-> 1574 hook_result = hook(self, args, result)
1576 if hook_result is not None:
1577 result = hook_result
File /scratch/aqa6122/anaconda3/lib/python3.12/site-packages/torchsummary/torchsummary.py:23, in summary.<locals>.register_hook.<locals>.hook(module, input, output)
20 summary[m_key]["input_shape"][0] = batch_size
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:
26 summary[m_key]["output_shape"] = list(output.size())
AttributeError: 'list' object has no attribute 'size'
Metadata
Metadata
Assignees
Labels
No labels