
Commit 110a22f

add shap.DeepExplainer
1 parent 570c7d0 commit 110a22f

File tree

7 files changed: +125 / -41 lines changed


README.md

Lines changed: 3 additions & 4 deletions
@@ -22,15 +22,14 @@ DisCERN requires the following packages:<br>
  | Attribution Explainer | scikit-learn | TensorFlow/Keras | PyTorch |
  |-----------------------|--------------|------------------|---------|
  | LIME | &check; | &check; | N/A |
- | SHAP | &check; | &check; | N/A |
+ | SHAP | &check; shap.TreeExplainer | &check; shap.DeepExplainer | N/A |
  | Integrated Gradients | &cross; | &check; | N/A |

-
  ## Getting Started with DisCERN

- An example of the Adult Income dataset using RandomForest and Keras Deep Neural Net classifiers are <a href="/tests/adult_income.py">here</a>
+ Binary Classification example on the Adult Income dataset using RandomForest and Keras Deep Neural Net classifiers are <a href="/tests/adult_income.py">here</a>

- <!--Multi-class Classification example using the Cancer risk dataset and RandomForest classifier <a href="/tests/cancer.py">here</a>-->
+ Multi-class Classification example on the Cancer risk dataset using RandomForest and Keras Deep Neural Net classifiers are <a href="/tests/cancer.py">here</a>

  ## Citing

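The updated table ties SHAP attribution to shap.TreeExplainer for scikit-learn models and to shap.DeepExplainer for TensorFlow/Keras models. Below is a minimal sketch of how that is driven from the caller's side, modelled on tests/cancer_risk.py in this commit; the import path, the synthetic dataset, and the omission of the optional cat_feature_indices/immutable_feature_indices keywords are assumptions for illustration, not documented API.

# Minimal sketch, modelled on tests/cancer_risk.py; import path and data are assumptions.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from discern.discern_tabular import DisCERNTabular  # assumed module path

# Illustrative binary-classification data; any tabular dataset works the same way.
x, y = make_classification(n_samples=500, n_features=6, n_informative=4, random_state=1)
feature_names = ['f%d' % i for i in range(x.shape[1])]

rf = RandomForestClassifier(n_estimators=100).fit(x, y)

# 'SHAP' attribution: this scikit-learn model gets a shap.TreeExplainer internally;
# a tf.keras.Model would get a shap.DeepExplainer instead (see discern/fa_shap.py below).
discern = DisCERNTabular(rf, 'SHAP')
discern.init_data(x, y, feature_names, ['negative', 'positive'])

query = x[0]
query_label = rf.predict([query])[0]
cf, cf_label, sparsity, proximity = discern.find_cf(query, query_label, cf_label=1 - query_label)
print(cf_label, sparsity, proximity)

For a tf.keras.Model the calls are the same; keras_test in tests/cancer_risk.py passes the Keras model and its predicted integer labels in place of the scikit-learn estimator.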
discern/discern_base.py

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ def init_data(self, train_data, train_labels, feature_names, labels, **kwargs):
              raise ValueError("DisCERN requires feature names.")
          if len(self.labels) == 0:
              raise ValueError("DisCERN requires class names.")
-         if len(self.labels) != len(set(self.train_labels)):
-             raise ValueError("Mismatch between class names and number of classes.")
+         # if len(self.labels) != len(set(self.train_labels)):
+         #     raise ValueError("Mismatch between class names and number of classes.")
          if len(self.feature_names) != self.train_data.shape[1]:
              raise ValueError("Mismatch between number of features and training data.")

discern/discern_tabular.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def init_rel(self):
          if self.attrib == 'LIME':
              self.feature_attrib = FeatureAttributionLIME(self.model, self.feature_names, train_data=self.train_data, labels=self.labels)
          elif self.attrib == 'SHAP':
-             self.feature_attrib = FeatureAttributionSHAP(self.model, self.feature_names)
+             self.feature_attrib = FeatureAttributionSHAP(self.model, self.feature_names, train_data=self.train_data)
          elif self.attrib == 'IntG':
              self.feature_attrib = FeatureAttributionIntG(self.model)
          else:

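For reference, the attribution object that init_rel now builds for SHAP can also be used on its own. A small sketch, assuming the constructor and explain_instance signatures from discern/fa_shap.py below and a shap release whose shap_values() returns one array per class (the format the indexing there expects); the toy data and variable names are illustrative only.

# Sketch only: toy data; signatures taken from discern/fa_shap.py in this commit.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from discern.fa_shap import FeatureAttributionSHAP  # assumed module path

x_train = np.random.rand(200, 4)
y_train = (x_train[:, 0] > 0.5).astype(int)
rf = RandomForestClassifier(n_estimators=50).fit(x_train, y_train)

# train_data is now passed through so a Keras model can be given a DeepExplainer
# background; for a scikit-learn classifier it is simply stored.
fa = FeatureAttributionSHAP(rf, ['f0', 'f1', 'f2', 'f3'], train_data=x_train)
weights = fa.explain_instance(x_train[0], query_label=int(y_train[0]))
print(weights)  # [(feature_index, shap_value), ...] for the given class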
discern/fa_shap.py

Lines changed: 20 additions & 5 deletions
@@ -1,15 +1,30 @@
  from discern.fa_base import FeatureAttribution
  import shap
  import pandas as pd
+ from sklearn.base import ClassifierMixin
+ import tensorflow as tf
+ import numpy as np

  class FeatureAttributionSHAP(FeatureAttribution):

-     def __init__(self, model, feature_names):
+     def __init__(self, model, feature_names, train_data):
          super().__init__(model)
          self.feature_names = feature_names
-         self.shap_explainer = shap.TreeExplainer(self.model)
+         self.train_data = train_data
+         if isinstance(self.model, ClassifierMixin):
+             self.shap_explainer = shap.TreeExplainer(self.model)
+         elif isinstance(self.model, tf.keras.Model):
+             self.shap_explainer = shap.DeepExplainer(self.model, self.train_data)
+

      def explain_instance(self, query, query_label=None, nun=None):
-         i_exp = pd.DataFrame([query], columns=self.feature_names)
-         shap_values = self.shap_explainer.shap_values(i_exp)
-         return [(i,w) for i,w in enumerate(shap_values[int(query_label)][0])]
+         if isinstance(self.model, ClassifierMixin):
+             i_exp = pd.DataFrame([query], columns=self.feature_names)
+             shap_values = self.shap_explainer.shap_values(i_exp)
+             return [(i,w) for i,w in enumerate(shap_values[int(query_label)][0])]
+         elif isinstance(self.model, tf.keras.Model):
+             shap_values = self.shap_explainer.shap_values(np.array([query]))
+             print(shap_values)
+             return [(i,w) for i,w in enumerate(shap_values[int(query_label)][0])]
+
+

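The isinstance branch above leans on two shap explainers. A standalone sketch of the same split outside DisCERN, again assuming a shap release whose shap_values() returns a list with one array per class or model output, and a TensorFlow/Keras build that shap.DeepExplainer supports; data and model sizes are arbitrary.

# Sketch only: random data and tiny models; shapes are for the list-per-class shap API.
import numpy as np
import shap
import tensorflow as tf
from sklearn.ensemble import RandomForestClassifier

x_bg = np.random.rand(200, 5).astype('float32')
y_bg = (x_bg[:, 0] > 0.5).astype(int)

# scikit-learn classifier: shap.TreeExplainer needs only the fitted model.
rf = RandomForestClassifier(n_estimators=50).fit(x_bg, y_bg)
tree_sv = shap.TreeExplainer(rf).shap_values(x_bg[:1])
print(len(tree_sv), np.shape(tree_sv[0]))  # one (1, 5) array per class

# Keras model: shap.DeepExplainer also needs background data (here, the training set).
inputs = tf.keras.Input(shape=(5,))
hidden = tf.keras.layers.Dense(16, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(2, activation='softmax')(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
model.fit(x_bg, y_bg, epochs=1, verbose=0)

deep_sv = shap.DeepExplainer(model, x_bg).shap_values(x_bg[:1])
print(len(deep_sv), np.shape(deep_sv[0]))  # one (1, 5) array per model output

Passing the whole training set as the DeepExplainer background mirrors the constructor above; shap's documentation generally suggests a smaller background sample when speed matters.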
discern/util.py

Lines changed: 3 additions & 4 deletions
@@ -9,7 +9,6 @@ def nun(data, labels, query, query_label, cf_label):
      top_labels = [labels[j] for j in top_indices[:sample_size]]
      for i, lab in enumerate(top_labels):
          if query_label != lab and lab == cf_label:
-             nun_index = i
-             break
-     return data[top_indices[nun_index]], top_labels[nun_index]
-
+             nun_index = i
+             return data[top_indices[nun_index]], top_labels[nun_index]
+     raise Exception('NUN not found.')

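After this change nun() returns the nearest neighbour whose label equals the requested counterfactual class, scanning outward until it finds one, and raises if the sample contains none; the old version broke out of the loop and could reach the final return with nun_index never assigned. A toy sketch of that behaviour, where toy_nun and its distance sort are hypothetical stand-ins for whatever ordering produces top_indices upstream:

# Toy sketch of the nearest-unlike-neighbour lookup; toy_nun is a hypothetical helper.
import numpy as np

def toy_nun(data, labels, query, query_label, cf_label):
    top_indices = np.argsort(np.linalg.norm(data - query, axis=1))  # nearest first
    for i in top_indices:
        if query_label != labels[i] and labels[i] == cf_label:
            return data[i], labels[i]  # first match wins, as in the patched util.nun
    raise Exception('NUN not found.')  # same failure mode as the new code

data = np.array([[0.1, 0.2], [0.9, 0.8], [0.2, 0.1], [0.8, 0.9]])
labels = np.array([0, 1, 0, 1])
print(toy_nun(data, labels, np.array([0.15, 0.15]), query_label=0, cf_label=1))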
setup.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
  import setuptools

- VERSION_STR = "0.0.26"
+ VERSION_STR = "0.0.27"

  with open("README.md", "r") as fh:
      long_description = fh.read()

tests/cancer_risk.py

Lines changed: 95 additions & 24 deletions
@@ -4,9 +4,12 @@
  from sklearn.preprocessing import MinMaxScaler
  from sklearn.ensemble import RandomForestClassifier
  from sklearn.metrics import accuracy_score
+ import os
+ import numpy as np
+ import tensorflow as tf

- def test_cancer_risk():
-     data_df = pd.read_csv('lung_cancer.csv')
+ def sklearn_test(attrib):
+     data_df = pd.read_csv(os.path.join(os.path.dirname(__file__), 'lung_cancer.csv'))
      data_df = data_df.replace({'Level': {'Low': 0, 'Medium': 1, 'High': 2}})
      data_df = data_df.replace({'Gender': {2: 0}})
      data_df = data_df.replace({'Alcohol use': {2: 0}})
@@ -28,37 +31,105 @@ def test_cancer_risk():
      print("Train test split complete!")

      scaler = MinMaxScaler()
-     x_train= scaler.fit_transform(x_train)
-     x_test = scaler.transform(x_test)
+     x_train_norm = scaler.fit_transform(x_train)
+     x_test_norm = scaler.transform(x_test)
      print("Data transform complete!")

-     rfx = RandomForestClassifier(n_estimators=500)
+     rfx = RandomForestClassifier(n_estimators=100)
      rfx.fit(x_train, y_train)
      print(accuracy_score(y_test, rfx.predict(x_test)))
      print("Training classifier complete!")

-     x_test = x_test[:10]
-     y_test = rfx.predict(x_test[:10])
+     test_instance = x_test_norm[10]
+     test_label = rfx.predict([x_test_norm[10]])[0]
      cat_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+     imm_indices = [0, 1, 2]
+     discern = DisCERNTabular(rfx, attrib)
+     discern.init_data(x_train_norm, y_train, [c for c in df.columns if c!='Level'], ['Low', 'Medium', 'High'], cat_feature_indices=cat_indices, immutable_feature_indices=imm_indices)

-     sparsity = []
-     proximity = []
-     discern = DisCERNTabular(rfx, 'LIME', 'Q')
-     discern.init_data(x_train, y_train, [c for c in df.columns if c!='Level'], ['Low', 'Medium', 'High'], cat_feature_indices=cat_indices)
+     cf, cf_label, s, p = discern.find_cf(test_instance, test_label, cf_label=0)
+     print('---------------------sklearn-'+attrib+'---------------------')
+     print(cf, cf_label)
+     print(test_instance, test_label)
+     print("Sparsity: ",s, "Proximity: ", p)

-     for idx in range(len(x_test)):
-         if y_test[idx] == 0:
-             continue
-         cf, s, p = discern.find_cf(x_test[idx], y_test[idx], desired_class='Low')
-         print(s)
-         print(p)
-         sparsity.append(s)
-         proximity.append(p)

-     _sparsity = sum(sparsity)/len(sparsity)
-     _proximity = sum(proximity)/(len(proximity)*_sparsity)
-     print(_sparsity)
-     print(_proximity)
+ def keras_test(attrib):
+     data_df = pd.read_csv(os.path.join(os.path.dirname(__file__), 'lung_cancer.csv'))
+     data_df = data_df.replace({'Level': {'Low': 0, 'Medium': 1, 'High': 2}})
+     data_df = data_df.replace({'Gender': {2: 0}})
+     data_df = data_df.replace({'Alcohol use': {2: 0}})
+     data_df = data_df.replace({'Dust Allergy': {2: 0}})
+     data_df = data_df.replace({'Smoking': {2: 0}})
+     data_df = data_df.replace({'Chest Pain': {2: 0}})
+     data_df = data_df.replace({'Fatigue': {2: 0}})
+     data_df = data_df.replace({'Shortness of Breath': {2: 0}})
+     data_df = data_df.replace({'Wheezing': {2: 0}})
+     data_df = data_df.replace({'Swallowing Difficulty': {2: 0}})
+     data_df = data_df.replace({'Cough': {2: 0}})
+     data_df = data_df.replace({'chronic Lung Disease': {2: 0}})
+     print("Reading data complete!")
+
+     df = data_df.copy()
+     x = df.loc[:, df.columns != 'Level'].values
+     y = df['Level'].values
+     x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=1)
+     print("Train test split complete!")
+
+     scaler = MinMaxScaler()
+     x_train_norm = scaler.fit_transform(x_train)
+     x_test_norm = scaler.transform(x_test)
+     y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=len(df['Level'].unique()), dtype='float32')
+     y_test_cat = tf.keras.utils.to_categorical(y_test, num_classes=len(df['Level'].unique()), dtype='float32')
+     print("Data transform complete!")
+
+     inputs = tf.keras.Input(shape=(x_train_norm.shape[-1],))
+     hidden1 = tf.keras.layers.Dense(64, activation='relu')(inputs)
+     hidden2 = tf.keras.layers.Dense(64, activation='relu')(hidden1)
+     outputs = tf.keras.layers.Dense(len(df['Level'].unique()), activation='softmax')(hidden2)
+
+     model = tf.keras.Model(inputs=inputs, outputs=outputs, name="model")
+
+     model.compile(
+         loss='categorical_crossentropy',
+         optimizer='Adam',
+         metrics=['accuracy'])
+
+     model.fit(x_train_norm, y_train_cat, validation_data=(x_test, y_test_cat), batch_size=32, epochs=5, verbose=0)
+     print("Training classifier complete: ", accuracy_score(y_test, model.predict(x_test_norm).argmax(axis=-1)))
+
+     test_instance = x_test_norm[12]
+     test_label = model.predict(np.array([x_test_norm[12]])).argmax(axis=-1)[0]
+
+     cat_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+     imm_indices = [0, 1, 2]
+     discern = DisCERNTabular(model, attrib)
+     print('labels', set( model.predict(x_train_norm).argmax(axis=-1)))
+     discern.init_data(x_train_norm, model.predict(x_train_norm).argmax(axis=-1), [c for c in df.columns if c!='Level'], ['Low', 'Medium', 'High'], cat_feature_indices=cat_indices, immutable_feature_indices=imm_indices)

+     cf, cf_label, s, p = discern.find_cf(test_instance, test_label, cf_label=0)
+     print('---------------------sklearn-'+attrib+'---------------------')
+     print(cf, cf_label)
+     print(test_instance, test_label)
+     print("Sparsity: ",s, "Proximity: ", p)

- test_cancer_risk()
+ try:
+     sklearn_test('LIME')
+ except:
+     None
+ try:
+     sklearn_test('SHAP')
+ except:
+     None
+ try:
+     keras_test('LIME')
+ except:
+     None
+ try:
+     keras_test('SHAP')
+ except:
+     None
+ try:
+     keras_test('IntG')
+ except:
+     None
