Skip to content

Commit ce229d6

Browse files
committed
refactor auggam -> auglinear
1 parent 78b43b1 commit ce229d6

File tree

9 files changed

+109
-142
lines changed

9 files changed

+109
-142
lines changed

demo_notebooks/aug_imodels.ipynb

Lines changed: 78 additions & 115 deletions
Large diffs are not rendered by default.

imodelsx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
.. include:: ../readme.md
33
"""
44

5-
from .auggam.auggam import AugGAMClassifier, AugGAMRegressor
5+
from .auglinear.auglinear import AugLinearClassifier, AugLinearRegressor
66
from .augtree.augtree import AugTreeClassifier, AugTreeRegressor
77
from .linear_finetune import LinearFinetuneClassifier, LinearFinetuneRegressor
88
from .linear_ngram import LinearNgramClassifier, LinearNgramRegressor
File renamed without changes.

imodelsx/auggam/auggam.py renamed to imodelsx/auglinear/auglinear.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""
2-
Simple scikit-learn interface for Emb-GAM.
2+
Simple scikit-learn interface for Aug-Linear.
33
4-
5-
Aug-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models
6-
Chandan Singh & Jianfeng Gao
4+
Augmenting Interpretable Models with LLMs during Training
5+
Chandan Singh, Armin Askari, Rich Caruana, Jianfeng Gao
76
https://arxiv.org/abs/2209.11799
87
"""
98
from numpy.typing import ArrayLike
@@ -15,7 +14,7 @@
1514
from sklearn.utils.validation import check_is_fitted
1615
from sklearn.preprocessing import StandardScaler
1716
import transformers
18-
import imodelsx.auggam.embed
17+
import imodelsx.auglinear.embed
1918
from tqdm import tqdm
2019
import os
2120
import os.path
@@ -29,7 +28,7 @@
2928
device = "cuda" if torch.cuda.is_available() else "cpu"
3029

3130

32-
class AugGAM(BaseEstimator):
31+
class AugLinear(BaseEstimator):
3332
def __init__(
3433
self,
3534
checkpoint: str = "bert-base-uncased",
@@ -41,11 +40,10 @@ def __init__(
4140
random_state=None,
4241
normalize_embs=False,
4342
cache_embs_dir: str = None,
44-
cache_coefs_dir: str = None,
4543
fit_with_ngram_decomposition=True,
4644
instructor_prompt=None,
4745
):
48-
"""AugGAM-GAM Class - use either AugGAMClassifier or AugGAMRegressor rather than initializing this class directly.
46+
"""AugLinear Class - use either AugLinearClassifier or AugLinearRegressor rather than initializing this class directly.
4947
5048
Parameters
5149
----------
@@ -68,7 +66,7 @@ def __init__(
6866
cache_embs_dir: str = None,
6967
if not None, directory to save embeddings into
7068
fit_with_ngram_decomposition
71-
whether to fit to emb-gam style (using sum of embeddings of each ngram)
69+
whether to fit to aug-linear style (using sum of embeddings of each ngram)
7270
if False, fits a typical model and uses ngram decomposition only for prediction / testing
7371
Usually, setting this to False will considerably impede performance
7472
instructor_prompt
@@ -166,7 +164,7 @@ def fit(
166164
def _get_embs_summed(self, X, model, tokenizer_embeddings, batch_size):
167165
embs = []
168166
for x in tqdm(X):
169-
emb = imodelsx.auggam.embed.embed_and_sum_function(
167+
emb = imodelsx.auglinear.embed.embed_and_sum_function(
170168
x,
171169
model=model,
172170
ngrams=self.ngrams,
@@ -276,6 +274,7 @@ def normalize_embs(embs, renormalize_embs):
276274
linear_coef = embs @ coef_embs
277275

278276
# save coefs
277+
linear_coef = linear_coef.squeeze()
279278
self.coefs_dict_ = {
280279
**coefs_dict_old,
281280
**{ngrams_list[i]: linear_coef[i] for i in range(len(ngrams_list))},
@@ -302,7 +301,7 @@ def _get_embs(self, ngrams_list, model, tokenizer_embeddings, batch_size):
302301
embs = np.vstack(embs).squeeze()
303302

304303
else:
305-
embs = imodelsx.auggam.embed.embed_and_sum_function(
304+
embs = imodelsx.auglinear.embed.embed_and_sum_function(
306305
ngrams_list,
307306
model=model,
308307
ngrams=None,
@@ -347,12 +346,14 @@ def predict(self, X, warn=True):
347346
"""For regression returns continuous output.
348347
For classification, returns discrete output.
349348
"""
349+
350350
check_is_fitted(self)
351351
preds = self._predict_cached(X, warn=warn)
352352
if isinstance(self, RegressorMixin):
353353
return preds
354354
elif isinstance(self, ClassifierMixin):
355-
if preds.ndim > 1: # multiclass classification
355+
# multiclass classification
356+
if preds.ndim > 1:
356357
return np.argmax(preds, axis=1)
357358
else:
358359
return (preds + self.linear.intercept_ > 0).astype(int)
@@ -398,14 +399,14 @@ def _predict_cached(self, X, warn):
398399
For better performance, call cache_linear_coefs on the test dataset \
399400
before calling predict."
400401
)
401-
return np.array(preds)
402+
return np.array(preds).squeeze()
402403

403404

404-
class AugGAMRegressor(AugGAM, RegressorMixin):
405+
class AugLinearRegressor(AugLinear, RegressorMixin):
405406
...
406407

407408

408-
class AugGAMClassifier(AugGAM, ClassifierMixin):
409+
class AugLinearClassifier(AugLinear, ClassifierMixin):
409410
...
410411

411412

File renamed without changes.

readme.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
| AutoPrompt | ㅤㅤ[🗂️](), [🔗](https://github.com/ucinlp/autoprompt), [📄](https://arxiv.org/abs/2010.15980) | Explanation<br/>+ Steering | Find a natural-language prompt<br/>using input-gradients (⌛ In progress)|
2323
| D3 | [🗂️](http://csinva.io/imodelsX/d3/d3.html#imodelsx.d3.d3.explain_dataset_d3), [🔗](https://github.com/ruiqi-zhong/DescribeDistributionalDifferences), [📄](https://arxiv.org/abs/2201.12323), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/d3.ipynb) | Explanation | Explain the difference between two distributions |
2424
| SASC | ㅤㅤ[🗂️](https://csinva.io/imodelsX/sasc/api.html), [🔗](https://github.com/microsoft/automated-explanations), [📄](https://arxiv.org/abs/2305.09863) | Explanation | Explain a black-box text module<br/>using an LLM (*Official*) |
25-
| Aug-GAM | [🗂️](https://csinva.io/imodelsX/auggam/auggam.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Linear model | Fit better linear model using an LLM<br/>to extract embeddings (*Official*) |
25+
| Aug-Linear | [🗂️](https://csinva.io/imodelsX/auglinear/auglinear.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Linear model | Fit better linear model using an LLM<br/>to extract embeddings (*Official*) |
2626
| Aug-Tree | [🗂️](https://csinva.io/imodelsX/augtree/augtree.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Decision tree | Fit better decision tree using an LLM<br/>to expand features (*Official*) |
2727

2828
<p align="center">
@@ -167,7 +167,7 @@ explanation_dict = explain_module_sasc(
167167
Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test-time they are extremely fast and completely transparent.
168168

169169
```python
170-
from imodelsx import AugGAMClassifier, AugTreeClassifier, AugGAMRegressor, AugTreeRegressor
170+
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
171171
import datasets
172172
import numpy as np
173173

@@ -178,7 +178,7 @@ dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
178178
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))
179179

180180
# fit model
181-
m = AugGAMClassifier(
181+
m = AugLinearClassifier(
182182
checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
183183
ngrams=2, # use bigrams
184184
)

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
setuptools.setup(
3030
name="imodelsx",
31-
version="0.4.2",
32-
author="Chandan Singh, John X. Morris, Armin Askari, Divyanshu Aggarwal, Aliyah Hsu",
31+
version="0.4.0",
32+
author="Chandan Singh, John X. Morris, Armin Askari, Divyanshu Aggarwal, Aliyah Hsu, Yuntian Deng",
3333
author_email="chansingh@microsoft.com",
3434
description="Library to explain a dataset in natural language.",
3535
long_description=long_description,
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset, AugGAMClassifier
1+
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset, AugLinearClassifier
2+
23

34
def test_auggam():
4-
m = AugGAMClassifier()
5+
m = AugLinearClassifier()

tests/test_pipeline.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from imodelsx import AugGAMClassifier
1+
from imodelsx import AugLinearClassifier
22
import datasets
33
import numpy as np
44

@@ -11,7 +11,7 @@
1111
len(dset_val), size=10, replace=False))
1212

1313
# fit model
14-
m = AugGAMClassifier(
14+
m = AugLinearClassifier(
1515
checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
1616
ngrams=2,
1717
all_ngrams=True, # also use lower-order ngrams
@@ -27,8 +27,10 @@
2727
# check results when varying batch size
2828
m.fit(dset['text'], dset['label'], batch_size=16)
2929
preds_check = m.predict(dset_val['text'])
30-
assert np.allclose(preds, preds_check), 'predictions should be same when varying batch size'
31-
assert np.allclose(np.array(list(m.coefs_dict_.values())), coefs_orig), 'coefs should be same when varying batch size'
30+
assert np.allclose(
31+
preds, preds_check), 'predictions should be same when varying batch size'
32+
assert np.allclose(np.array(list(m.coefs_dict_.values())),
33+
coefs_orig), 'coefs should be same when varying batch size'
3234

3335
# interpret
3436
print('Total ngram coefficients: ', len(m.coefs_dict_))

0 commit comments

Comments
 (0)