Commit c5acb4e ("update readme"), 1 parent 0ce1f45

1 file changed: readme.md (59 additions, 4 deletions)
```diff
@@ -19,9 +19,9 @@
 | :-------------------------- | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ |
 | Tree-Prompt | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/tree_prompt.ipynb), [🗂️](http://csinva.io/imodelsX/treeprompt/treeprompt.html), [🔗](https://github.com/csinva/tree-prompt/tree/main), [📄]() | Explanation<br/>+ Steering | Generates a tree of prompts to<br/>steer an LLM (*Official*) |
 | iPrompt | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/iprompt.ipynb), [🗂️](http://csinva.io/imodelsX/iprompt/api.html#imodelsx.iprompt.api.explain_dataset_iprompt), [🔗](https://github.com/csinva/interpretable-autoprompting), [📄](https://arxiv.org/abs/2210.01848) | Explanation<br/>+ Steering | Generates a prompt that<br/>explains patterns in data (*Official*) |
+| AutoPrompt | &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ㅤㅤ[🗂️](), [🔗](https://github.com/ucinlp/autoprompt), [📄](https://arxiv.org/abs/2010.15980) | Explanation<br/>+ Steering | Find a natural-language prompt<br/>using input-gradients (⌛ In progress)|
 | D3 | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/d3.ipynb), [🗂️](http://csinva.io/imodelsX/d3/d3.html#imodelsx.d3.d3.explain_dataset_d3), [🔗](https://github.com/ruiqi-zhong/DescribeDistributionalDifferences), [📄](https://arxiv.org/abs/2201.12323) | Explanation | Explain the difference between two distributions |
 | SASC | &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ㅤㅤ[🗂️](https://csinva.io/imodelsX/sasc/api.html), [🔗](https://github.com/microsoft/automated-explanations), [📄](https://arxiv.org/abs/2305.09863) | Explanation | Explain a black-box text module<br/>using an LLM (*Official*) |
-| AutoPrompt | &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ㅤㅤ[🗂️](), [🔗](https://github.com/ucinlp/autoprompt), [📄](https://arxiv.org/abs/2010.15980) | Explanation | Find a natural-language prompt<br/>using input-gradients (⌛ In progress)|
 | Aug-GAM | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb), [🗂️](https://csinva.io/imodelsX/auggam/auggam.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://arxiv.org/abs/2209.11799) | Linear model | Fit better linear model using an LLM<br/>to extract embeddings (*Official*) |
 | Aug-Tree | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb), [🗂️](https://csinva.io/imodelsX/augtree/augtree.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://arxiv.org/abs/2209.11799) | Decision tree | Fit better decision tree using an LLM<br/>to expand features (*Official*) |
 
```

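The AutoPrompt row moved above searches for prompt tokens via input gradients (a HotFlip-style first-order search). A toy numpy sketch of its candidate-scoring step, with made-up embeddings and a made-up gradient standing in for a real backward pass through the LM:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 50, 8
embedding = rng.normal(size=(vocab_size, dim))  # toy token embedding table

# gradient of the loss w.r.t. the current trigger token's embedding
# (in real AutoPrompt this comes from backprop through the language model)
grad = rng.normal(size=dim)
cur_token = 3

# first-order approximation of the loss change from swapping the current
# trigger token for each vocabulary token; pick the best decrease
scores = (embedding - embedding[cur_token]) @ (-grad)
best_candidate = int(np.argmax(scores))
```

The real algorithm re-evaluates the top-scoring candidates with forward passes before committing to a swap; this sketch only shows the gradient-based shortlist.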
````diff
@@ -47,10 +47,65 @@
 **Demos**: see the [demo notebooks](https://github.com/csinva/imodelsX/tree/master/demo_notebooks)
 
 
+# Natural-language explanations
+
+### Tree-prompt
+```python
+from imodelsx import TreePromptClassifier
+import datasets
+import numpy as np
+from sklearn.tree import plot_tree
+import matplotlib.pyplot as plt
 
-# Explainable models
+# set up data
+rng = np.random.default_rng(seed=42)
+dset_train = datasets.load_dataset('rotten_tomatoes')['train']
+dset_train = dset_train.select(rng.choice(
+    len(dset_train), size=100, replace=False))
+dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
+dset_val = dset_val.select(rng.choice(
+    len(dset_val), size=100, replace=False))
+
+# set up arguments
+prompts = [
+    "This movie is",
+    " Positive or Negative? The movie was",
+    " The sentiment of the movie was",
+    " The plot of the movie was really",
+    " The acting in the movie was",
+]
+verbalizer = {0: " Negative.", 1: " Positive."}
+checkpoint = "gpt2"
 
-# Natural-language explanations
+# fit model
+m = TreePromptClassifier(
+    checkpoint=checkpoint,
+    prompts=prompts,
+    verbalizer=verbalizer,
+    cache_prompt_features_dir=None,  # 'cache_prompt_features_dir/gp2',
+)
+m.fit(dset_train["text"], dset_train["label"])
+
+# compute accuracy
+preds = m.predict(dset_val['text'])
+print('\nTree-Prompt acc (val) ->',
+      np.mean(preds == dset_val['label']))  # -> 0.7
+
+# compare to accuracy for individual prompts
+for i, prompt in enumerate(prompts):
+    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51
+
+# visualize decision tree
+plot_tree(
+    m.clf_,
+    fontsize=10,
+    feature_names=m.feature_names_,
+    class_names=list(verbalizer.values()),
+    filled=True,
+)
+plt.show()
+```
 
 ### iPrompt
 
````
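Under the hood, the Tree-Prompt example above turns each prompt into a discrete feature by asking the LLM which verbalizer completion it prefers, then fits a decision tree over those features. A minimal pure-Python sketch of that featurization step, with a hypothetical `toy_llm_score` standing in for GPT-2 logits (unlike the real model, the toy scorer ignores the prompt):

```python
# hypothetical scorer standing in for GPT-2 next-token logits over
# `prompt + text + completion`; here it just counts sentiment words
def toy_llm_score(prompt, text, completion):
    positive_words = {"great", "good", "fun"}
    hits = sum(w in positive_words for w in text.split())
    return hits if "Positive" in completion else 1 - hits

verbalizer = {0: " Negative.", 1: " Positive."}

def prompt_feature(prompt, text):
    # the feature value is the class whose verbalizer completion scores highest
    scores = {cls: toy_llm_score(prompt, text, tok) for cls, tok in verbalizer.items()}
    return max(scores, key=scores.get)

prompts = ["This movie is", " The sentiment of the movie was"]
texts = ["a great fun ride", "a dull mess"]

# one discrete feature per prompt; a decision tree is then fit on this matrix
X = [[prompt_feature(p, t) for p in prompts] for t in texts]
```

This is why the fitted `m.clf_` can be drawn with `plot_tree`: after featurization, it is an ordinary sklearn decision tree whose split features are prompts.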
````diff
@@ -108,7 +163,7 @@ explanation_dict = explain_module_sasc(
 )
 ```
 
-### Aug-imodels
+# Aug-imodels
 Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test-time they are extremely fast and completely transparent.
 
 ```python
````

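To make "extremely fast at test time" concrete, here is a toy sketch of the Aug-GAM recipe: extract features once with an expensive model, fit a plain linear model, and serve predictions with nothing but a lookup and a dot product. The `embed` function (a hand-picked ngram counter) is a hypothetical stand-in for the LLM embedding extraction the library actually performs:

```python
import numpy as np

VOCAB = ["great", "fun", "dull", "mess"]

def embed(text):
    # stand-in for LLM feature extraction: counts of hand-picked unigrams
    words = text.split()
    return np.array([words.count(w) for w in VOCAB], dtype=float)

# training: expensive featurization (the LLM in real Aug-GAM), then a linear fit
X = np.array([embed(t) for t in
              ["a great fun ride", "a dull mess", "great acting", "a mess"]])
y = np.array([1, 0, 1, 0])
w = np.linalg.lstsq(X, y, rcond=None)[0]  # least-squares linear model

# inference: ngram lookup + dot product, no LLM in the loop
def predict(text):
    return int(embed(text) @ w > 0.5)
```

The transparency claim follows from the same structure: each learned weight in `w` is attached to a named feature, so the model can be read off directly.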