| :-------------------------- | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ |
| Tree-Prompt | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/tree_prompt.ipynb), [🗂️](http://csinva.io/imodelsX/treeprompt/treeprompt.html), [🔗](https://github.com/csinva/tree-prompt/tree/main), [📄]() | Explanation<br/>+ Steering | Generates a tree of prompts to<br/>steer an LLM (*Official*) |
| iPrompt | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/iprompt.ipynb), [🗂️](http://csinva.io/imodelsX/iprompt/api.html#imodelsx.iprompt.api.explain_dataset_iprompt), [🔗](https://github.com/csinva/interpretable-autoprompting), [📄](https://arxiv.org/abs/2210.01848) | Explanation<br/>+ Steering | Generates a prompt that<br/>explains patterns in data (*Official*) |
| AutoPrompt | ㅤㅤ[🗂️](), [🔗](https://github.com/ucinlp/autoprompt), [📄](https://arxiv.org/abs/2010.15980) | Explanation<br/>+ Steering | Find a natural-language prompt<br/>using input-gradients (⌛ In progress) |
| D3 | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/d3.ipynb), [🗂️](http://csinva.io/imodelsX/d3/d3.html#imodelsx.d3.d3.explain_dataset_d3), [🔗](https://github.com/ruiqi-zhong/DescribeDistributionalDifferences), [📄](https://arxiv.org/abs/2201.12323) | Explanation | Explain the difference between two distributions |
| SASC | ㅤㅤ[🗂️](https://csinva.io/imodelsX/sasc/api.html), [🔗](https://github.com/microsoft/automated-explanations), [📄](https://arxiv.org/abs/2305.09863) | Explanation | Explain a black-box text module<br/>using an LLM (*Official*) |
| Aug-GAM | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb), [🗂️](https://csinva.io/imodelsX/auggam/auggam.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://arxiv.org/abs/2209.11799) | Linear model | Fit better linear model using an LLM<br/>to extract embeddings (*Official*) |
| Aug-Tree | [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb), [🗂️](https://csinva.io/imodelsX/augtree/augtree.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://arxiv.org/abs/2209.11799) | Decision tree | Fit better decision tree using an LLM<br/>to expand features (*Official*) |
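The AutoPrompt entry above finds prompts using input gradients: per the linked paper, candidate tokens for a prompt slot are ranked by a first-order (HotFlip-style) approximation, the dot product of each vocabulary embedding with the gradient at the slot. A toy numpy sketch of just that ranking step (the embedding table, gradient, and token ids here are all synthetic, not from any real model):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 50, 32
E = rng.normal(size=(vocab_size, dim))  # synthetic token-embedding table

# pretend the gradient of the task objective w.r.t. the current prompt-slot
# embedding points (noisily) toward token 7's embedding
grad = E[7] + 0.05 * rng.normal(size=dim)

# HotFlip-style ranking: the first-order estimate of the gain from
# placing token w in the slot is the dot product E[w] @ grad
scores = E @ grad
best_token = int(np.argmax(scores))
print("top-ranked candidate token id:", best_token)
```

The real method repeats this ranking over many batches and slots; this only illustrates the single gradient-guided token swap at its core.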

**Demos**: see the [demo notebooks](https://github.com/csinva/imodelsX/tree/master/demo_notebooks)

# Natural-language explanations

### Tree-prompt
```python
from imodelsx import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # e.g. 'cache_prompt_features_dir/gpt2'
)
m.fit(dset_train["text"], dset_train["label"])

# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
```
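As the `m.clf_` and `m.feature_names_` attributes above suggest, Tree-Prompt's final classifier is an ordinary decision tree whose features are the verbalized answers of each prompt. A self-contained toy sketch of that idea, with the LLM replaced by simulated prompt answers (all labels and flip rates below are synthetic):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 200
y = rng.integers(0, 2, size=n)  # true binary labels

# simulate 3 prompts: each prompt's verbalized answer matches the true
# label except on a per-prompt fraction of flipped examples
flip_rates = [0.1, 0.3, 0.45]
X = np.column_stack(
    [np.where(rng.random(n) < p, 1 - y, y) for p in flip_rates]
)

# Tree-Prompt's final step: fit an ordinary decision tree over the
# prompt-output features
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
acc = clf.score(X, y)
print(f"tree accuracy on simulated prompt features: {acc:.2f}")
```

Because the expensive LLM calls happen only when building `X`, the fitted tree itself is fast to evaluate and directly inspectable, which is what `plot_tree` visualizes in the real example.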

### iPrompt

)
```

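SASC (described in the table above) summarizes what drives a black-box text module and keeps the candidate explanation under which the module responds most strongly. A deliberately tiny sketch of that scoring step, with a hypothetical stand-in module and hand-made texts (nothing here is the real SASC API):

```python
import numpy as np

# hypothetical black-box "module": scores any string (stands in for,
# e.g., a single neuron inside a language model)
def module(text: str) -> float:
    return float("happy" in text) + 0.1 * len(text) / 100

texts = [
    "a happy uplifting film",
    "a sad slow drama",
    "happy songs all around",
    "a long boring movie",
    "the happiest ending imaginable",
]
candidates = ["happy", "sad", "movie"]  # candidate explanations

# scoring step: an explanation is better if the module responds more
# strongly to texts that match it than to texts that do not
def score(word: str) -> float:
    hits = [module(t) for t in texts if word in t]
    miss = [module(t) for t in texts if word not in t]
    return (np.mean(hits) if hits else 0.0) - (np.mean(miss) if miss else 0.0)

best = max(candidates, key=score)
print("best explanation:", best)
```

The real method additionally uses an LLM to propose the candidate explanations and generates synthetic texts for scoring; this only illustrates the response-gap criterion.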
|
# Aug-imodels
Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test time they are extremely fast and completely transparent.
|

```python