Skip to content

Commit 7086b7b

Browse files
committed
Add absolute value to montecarlo plot
1 parent 77314a5 commit 7086b7b

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

survlimepy/survlime_explainer.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def plot_montecarlo_weights(
394394
scale_with_data_point: bool = False,
395395
figure_path: Optional[str] = None,
396396
with_colour: bool = True,
397+
absolute_vals: bool = False,
397398
) -> None:
398399
"""Generates explanations for a prediction.
399400
@@ -403,6 +404,7 @@ def plot_montecarlo_weights(
403404
scale_with_data_point (bool): whether to perform the elementwise multiplication between the point to be explained and the coefficients.
404405
figure_path (Optional[str]): path to save the figure.
405406
with_colour (bool): boolean indicating whether the colour palette for positive coefficients is different than thecolour palette for negative coefficients. Default is set to True.
407+
absolute_vals (bool): whether to plot the absolute values of the coefficients.
406408
407409
Returns:
408410
None.
@@ -454,9 +456,25 @@ def plot_montecarlo_weights(
454456
colors_up = {key: val for key, val in zip(median_up.keys(), pal_up)}
455457
colors_down = {key: val for key, val in zip(median_down.keys(), pal_down)}
456458
custom_pal = {**colors_up, **colors_down}
457-
data_reindex = data.reindex(columns=custom_pal.keys())
458-
data_melt = pd.melt(data_reindex)
459-
459+
if absolute_vals:
460+
absolute_order = {**median_up, **median_down}
461+
absolute_order = {key: np.abs(val) for key, val in absolute_order.items()}
462+
absolute_order = dict(
463+
sorted(
464+
absolute_order.items(),
465+
key=lambda item: np.abs(item[1]),
466+
reverse=True,
467+
)
468+
)
469+
custom_pal = {key: custom_pal[key] for key in absolute_order.keys()}
470+
data_reindex = data.reindex(columns=list(custom_pal.keys()))
471+
data_melt = pd.melt(data_reindex)
472+
data_melt.value = np.abs(data_melt.value)
473+
plot_title = "Absolute feature importance"
474+
else:
475+
data_reindex = data.reindex(columns=custom_pal.keys())
476+
data_melt = pd.melt(data_reindex)
477+
plot_title = "Feature importance"
460478
_, ax = plt.subplots(figsize=figsize)
461479
ax.tick_params(labelrotation=90)
462480
if with_colour:
@@ -481,7 +499,7 @@ def plot_montecarlo_weights(
481499
p.yaxis.grid(True)
482500
p.xaxis.grid(True)
483501

484-
p.set_title("Feature importance", fontsize=16, fontweight="bold")
502+
p.set_title(plot_title, fontsize=16, fontweight="bold")
485503

486504
plt.xticks(fontsize=16, rotation=90)
487505
plt.yticks(fontsize=14, rotation=0)

0 commit comments

Comments
 (0)