Skip to content

Commit 4a6864f

Browse files
Cam addition (#7)
* Initialised cam parts * Added new class for CAM * Fixed licence link * Several fixes * Deleted unused libraries in LRP * Deleted unused libraries in Occlusion * Changes mock data paths * Initialised tests * Updated requirements and setup * Added coverage * Added tests * Added cam notebooks * Updated README * Updated badges in README
1 parent ccafc68 commit 4a6864f

21 files changed

+2959
-16
lines changed

.coveragerc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[run]
2+
source = easy_explain
3+
omit =
4+
*/tests/*
5+
*/venv/*
6+
7+
[report]
8+
exclude_lines =
9+
pragma: no cover
10+
def __repr__
11+
if self.debug:
12+
raise AssertionError
13+
raise NotImplementedError
14+
if __name__ == .__main__.:

README.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
[![GitHub][github_badge]][github_link]
77
[![PyPI][pypi_badge]][pypi_link]
88
[![Download][download_badge]][download_link]
9+
[![Download][total_download_badge]][download_link]
910
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
1011
[![Licence][licence_badge]][licence_link]
1112

@@ -51,6 +52,7 @@ There are also other customade algorithms to support other models like the LRP i
5152
Currently, `easy-explain` specializes in two cutting-edge XAI methodologies for images:
5253

5354
- Occlusion: For deep insight into classification model decisions.
55+
- Cam: SmoothGradCAMpp & LayerCAM for explainability on image classification models.
5456
- Layer-wise Relevance Propagation (LRP): Specifically tailored for YoloV8 models, unveiling the decision-making process in object detection tasks.
5557

5658
## Quick Start
@@ -83,6 +85,23 @@ explanation_lrp = lrp.explain(image, cls='your-class', contrastive=False).cpu()
8385
lrp.plot_explanation(frame=image, explanation = explanation_lrp, contrastive=True, cmap='seismic', title='Explanation for your class"')
8486
```
8587

88+
```python
89+
from easy_explain import YOLOv8LRP
90+
91+
model = 'your-model'
92+
image = 'your-image'
93+
94+
trans_params = {"ImageNet_transformation":
95+
{"Resize": {"h": 224,"w": 224},
96+
"Normalize": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}}}
97+
98+
explainer = CAMExplain(model)
99+
100+
input_tensor = explainer.transform_image(img, trans_params["ImageNet_transformation"])
101+
102+
explainer.generate_explanation(img, input_tensor, multiple_layers=["a_layer", "another_layer", "another_layer"])
103+
```
104+
86105
For more information about how to begin have a look at the [examples notebooks](https://github.com/stavrostheocharis/easy_explain/tree/main/examples).
87106

88107
## Examples
@@ -95,6 +114,8 @@ Explore how `easy-explain` can be applied in various scenarios:
95114

96115
![Use Case Example](easy_explain/images/siberian-positive.png "Use Case Example")
97116

117+
![Use Case Example](easy_explain/images/jiraffe-cam-method.png "Use Case Example")
118+
98119
![Use Case Example](easy_explain/images/class-traffic.png "Use Case Example")
99120

100121
## How to contribute?
@@ -118,10 +139,12 @@ Join us in making AI models more interpretable, transparent, and trustworthy wit
118139

119140
[pypi_link]: https://pypi.org/project/easy-explain/
120141

121-
[download_badge]: https://badgen.net/pypi/dm/easy-explain
142+
[download_badge]: https://static.pepy.tech/personalized-badge/easy-explain?period=month&units=international_system&left_color=grey&right_color=green&left_text=Monthly%20Downloads
143+
144+
[total_download_badge]: https://static.pepy.tech/personalized-badge/easy-explain?period=total&units=international_system&left_color=grey&right_color=green&left_text=Total%20Downloads
122145

123146
[download_link]: https://pypi.org/project/easy-explain/#files
124147

125-
[licence_badge]: https://img.shields.io/github/license/stavrostheocharis/easy-explain
148+
[licence_badge]: https://img.shields.io/github/license/stavrostheocharis/easy_explain
126149

127150
[licence_link]: LICENSE

easy_explain/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .methods import YOLOv8LRP, OcclusionExplain
1+
from .methods import YOLOv8LRP, OcclusionExplain, CAMExplain
22

3-
__all__ = ["YOLOv8LRP", "OcclusionExplain"]
3+
__all__ = ["YOLOv8LRP", "OcclusionExplain", "CAMExplain"]
2.22 MB
Loading

easy_explain/methods/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .lrp import YOLOv8LRP
22
from .occlusion import OcclusionExplain
3+
from .cam import CAMExplain
34

4-
__all__ = ["YOLOv8LRP", "OcclusionExplain"]
5+
__all__ = ["YOLOv8LRP", "OcclusionExplain", "CAMExplain"]

easy_explain/methods/cam/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .cam import CAMExplain
2+
3+
__all__ = ["CAMExplain"]

easy_explain/methods/cam/cam.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import torch
2+
from torchcam.methods import SmoothGradCAMpp, LayerCAM
3+
from torchcam.utils import overlay_mask
4+
from torchvision import transforms
5+
import matplotlib.pyplot as plt
6+
from typing import List, Optional, Dict, Any
7+
import logging
8+
from easy_explain.methods.xai_base import ExplainabilityMethod
9+
10+
11+
class CAMExplain(ExplainabilityMethod):
12+
def __init__(self, model: torch.nn.Module):
13+
self.model = model
14+
logging.basicConfig(level=logging.INFO)
15+
16+
def transform_image(
17+
self,
18+
img: torch.Tensor,
19+
trans_params: Dict[str, Dict[str, Any]],
20+
) -> torch.Tensor:
21+
"""
22+
Transforms an image using specified resizing and normalization parameters.
23+
24+
Args:
25+
img (Image.Image): The image to transform.
26+
trans_params (Dict[str, Dict[str, Any]]): Parameters for resizing and normalization.
27+
28+
Returns:
29+
torch.Tensor: The transformed image tensor.
30+
"""
31+
try:
32+
resize_params = trans_params["Resize"]
33+
normalize_params = trans_params["Normalize"]
34+
input_tensor = transforms.functional.normalize(
35+
transforms.functional.resize(
36+
img, (resize_params["h"], resize_params["w"])
37+
)
38+
/ 255.0,
39+
normalize_params["mean"],
40+
normalize_params["std"],
41+
)
42+
return input_tensor
43+
44+
except Exception as e:
45+
logging.error(f"Error transforming image: {e}")
46+
raise
47+
48+
def get_multiple_layers_result(
49+
self,
50+
img: torch.Tensor,
51+
input_tensor: torch.Tensor,
52+
layers: List[str],
53+
alpha: float,
54+
):
55+
"""
56+
Visualizes CAMs for multiple layers and their fused result.
57+
58+
Args:
59+
img (torch.Tensor): The original image tensor.
60+
input_tensor (torch.Tensor): The tensor to input to the model.
61+
layers (List[str]): List of layer names to visualize CAMs for.
62+
alpha (float): Alpha value for blending CAMs on the original image.
63+
"""
64+
try:
65+
# Retrieve the CAM from several layers at the same time
66+
cam_extractor = LayerCAM(self.model, layers)
67+
# Preprocess your data and feed it to the model
68+
output = self.model(input_tensor.unsqueeze(0))
69+
# Retrieve the CAM by passing the class index and the model output
70+
cams = cam_extractor(output.squeeze(0).argmax().item(), output)
71+
logging.info("Successfully retrieved CAMs for multiple layers")
72+
73+
cam_per_layer_list = []
74+
# Get the cam per target layer provided
75+
for cam in cams:
76+
cam_per_layer_list.append(cam.shape)
77+
78+
logging.info(f"The cams per target layer are: {cam_per_layer_list}")
79+
80+
# Raw CAM
81+
_, axes = plt.subplots(1, len(cam_extractor.target_names))
82+
for id, name, cam in zip(
83+
range(len(cam_extractor.target_names)), cam_extractor.target_names, cams
84+
):
85+
axes[id].imshow(cam.squeeze(0).numpy())
86+
axes[id].axis("off")
87+
axes[id].set_title(name)
88+
plt.show()
89+
90+
fused_cam = cam_extractor.fuse_cams(cams)
91+
# Plot the raw version
92+
plt.imshow(fused_cam.squeeze(0).numpy())
93+
plt.axis("off")
94+
plt.title(" + ".join(cam_extractor.target_names))
95+
plt.show()
96+
# Plot the overlayed version
97+
result = overlay_mask(
98+
transforms.functional.to_pil_image(img),
99+
transforms.functional.to_pil_image(fused_cam, mode="F"),
100+
alpha=alpha,
101+
)
102+
plt.imshow(result)
103+
plt.axis("off")
104+
plt.title(" + ".join(cam_extractor.target_names))
105+
plt.show()
106+
cam_extractor.remove_hooks()
107+
108+
except Exception as e:
109+
logging.error(f"Error retrieving CAMs for multiple layers: {e}")
110+
raise
111+
112+
def get_localisation_mask(self, input_tensor: torch.Tensor, img: torch.Tensor):
113+
"""
114+
Generates and visualizes localization masks based on CAMs.
115+
116+
Args:
117+
input_tensor (torch.Tensor): The tensor input to the model.
118+
img (torch.Tensor): The original image tensor.
119+
"""
120+
try:
121+
# Retrieve CAM for differnet layers at the same time
122+
cam_extractor = LayerCAM(self.model)
123+
output = self.model(input_tensor.unsqueeze(0))
124+
cams = cam_extractor(output.squeeze(0).argmax().item(), output)
125+
126+
# Transformations
127+
resized_cams = [
128+
transforms.functional.resize(
129+
transforms.functional.to_pil_image(cam.squeeze(0)), img.shape[-2:]
130+
)
131+
for cam in cams
132+
]
133+
segmaps = [
134+
transforms.functional.to_pil_image(
135+
(
136+
transforms.functional.resize(cam, img.shape[-2:]).squeeze(0)
137+
>= 0.5
138+
).to(dtype=torch.float32)
139+
)
140+
for cam in cams
141+
]
142+
143+
# Plots
144+
for name, cam, seg in zip(
145+
cam_extractor.target_names, resized_cams, segmaps
146+
):
147+
_, axes = plt.subplots(1, 2)
148+
axes[0].imshow(cam)
149+
axes[0].axis("off")
150+
axes[0].set_title(name)
151+
axes[1].imshow(seg)
152+
axes[1].axis("off")
153+
axes[1].set_title(name)
154+
plt.show()
155+
cam_extractor.remove_hooks()
156+
157+
except Exception as e:
158+
logging.error(f"Error generating localization masks: {e}")
159+
raise
160+
161+
def generate_explanation(
162+
self,
163+
img: torch.Tensor,
164+
input_tensor: torch.Tensor,
165+
target_layer: Optional[str] = None,
166+
localisation_mask: bool = True,
167+
multiple_layers: List[str] = [],
168+
alpha=0.5,
169+
):
170+
"""
171+
Extracts and visualizes CAMs for a target layer or multiple layers.
172+
173+
Args:
174+
img (torch.Tensor): The original image tensor.
175+
input_tensor (torch.Tensor): The tensor input to the model.
176+
target_layer (Optional[str]): The target layer for CAM visualization.
177+
localisation_mask (bool): Whether to generate localization masks.
178+
multiple_layers (List[str]): Layers for multi-layer CAM visualization.
179+
alpha (float): Alpha value for blending CAMs on the original image.
180+
"""
181+
try:
182+
cam_extractor = SmoothGradCAMpp(self.model, target_layer=target_layer)
183+
output = self.model(input_tensor.unsqueeze(0))
184+
# Get the CAM giving the class index and output
185+
cams = cam_extractor(output.squeeze(0).argmax().item(), output)
186+
187+
cam_per_layer_list = []
188+
# Get the cam per target layer provided
189+
for cam in cams:
190+
cam_per_layer_list.append(cam.shape)
191+
192+
logging.info(f"The cams per target layer are: {cam_per_layer_list}")
193+
194+
# The raw CAM
195+
for name, cam in zip(cam_extractor.target_names, cams):
196+
plt.imshow(cam.squeeze(0).numpy())
197+
plt.axis("off")
198+
plt.title(name)
199+
plt.show()
200+
201+
# Overlayed on the image
202+
for name, cam in zip(cam_extractor.target_names, cams):
203+
result = overlay_mask(
204+
transforms.functional.to_pil_image(img),
205+
transforms.functional.to_pil_image(cam.squeeze(0), mode="F"),
206+
alpha=alpha,
207+
)
208+
plt.imshow(result)
209+
plt.axis("off")
210+
plt.title(name)
211+
plt.show()
212+
213+
cam_extractor.remove_hooks()
214+
215+
if localisation_mask:
216+
self.get_localisation_mask(input_tensor, img)
217+
218+
if len(multiple_layers) > 0:
219+
self.get_multiple_layers_result(
220+
img, input_tensor, multiple_layers, alpha
221+
)
222+
223+
except Exception as e:
224+
logging.error(f"Error extracting CAM: {e}")
225+
raise

easy_explain/methods/lrp/yolov8/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
from scipy.ndimage import zoom
3-
from sklearn.model_selection import train_test_split
43
import numpy as np
54

65

@@ -60,7 +59,6 @@ def scale_mask(mask, shape):
6059

6160

6261
class LayerRelevance(torch.Tensor):
63-
6462
"""
6563
LayerRelevance(relevance=None, contrastive=False, print_decimals=5)
6664

easy_explain/methods/occlusion/occlusion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from captum.attr import visualization as viz
88
from captum.attr import Occlusion
99
import json
10-
from typing import Union, List, Dict, Any
11-
import itertools
12-
from easy_explain.methods.occlusion.xai_base import ExplainabilityMethod
10+
from typing import Union, List, Dict
11+
from easy_explain.methods.xai_base import ExplainabilityMethod
1312

1413

1514
class OcclusionExplain(ExplainabilityMethod):
Loading
2.15 MB
Loading

0 commit comments

Comments
 (0)