
Commit 5d29431

clean up auggam, add support for llama
1 parent 5d743bf commit 5d29431

File tree

5 files changed (+392 -159 lines)


imodelsx/auggam/auggam.py

Lines changed: 97 additions & 86 deletions
@@ -22,18 +22,19 @@
 import os.path
 import warnings
 import pickle as pkl
+from os.path import join
 import torch
-from torch.utils.data import DataLoader
-from datasets import Dataset
+from transformers import LlamaModel, LlamaTokenizer
 from sklearn.exceptions import ConvergenceWarning
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+device = "cuda" if torch.cuda.is_available() else "cpu"


 class AugGAM(BaseEstimator):
     def __init__(
         self,
-        checkpoint: str = 'bert-base-uncased',
-        layer: str = 'last_hidden_state',
+        checkpoint: str = "bert-base-uncased",
+        layer: str = "last_hidden_state",
         ngrams: int = 2,
         all_ngrams: bool = False,
         min_frequency: int = 1,
@@ -43,7 +44,7 @@ def __init__(
         fit_with_ngram_decomposition=True,
         instructor_prompt=None,
     ):
-        '''AugGAM-GAM Class - use either AugGAMClassifier or AugGAMRegressor rather than initializing this class directly.
+        """AugGAM-GAM Class - use either AugGAMClassifier or AugGAMRegressor rather than initializing this class directly.

         Parameters
         ----------
@@ -69,7 +70,7 @@ def __init__(
             Usually, setting this to False will considerably impede performance
         instructor_prompt
             if not None, use instructor-xl with this prompt
-        '''
+        """
         self.checkpoint = checkpoint
         self.ngrams = ngrams
         if tokenizer_ngrams == None:
@@ -84,12 +85,16 @@ def __init__(
         self.fit_with_ngram_decomposition = fit_with_ngram_decomposition
         self.instructor_prompt = instructor_prompt

-    def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
-            cache_linear_coefs: bool = True,
-            cache_embs_dir: str = None,
-            batch_size: int = 8
-            ):
-        '''Extract embeddings then fit linear model
+    def fit(
+        self,
+        X: ArrayLike,
+        y: ArrayLike,
+        verbose=True,
+        cache_linear_coefs: bool = True,
+        cache_embs_dir: str = None,
+        batch_size: int = 8,
+    ):
+        """Extract embeddings then fit linear model

         Parameters
         ----------
@@ -101,7 +106,7 @@ def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
             if not None, directory to save embeddings into
         batch_size, optional
             if not None, batch size to pass while calculating embeddings
-        '''
+        """

         # metadata
         if isinstance(self, ClassifierMixin):
@@ -111,25 +116,29 @@ def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,

         # set up model
         if verbose:
-            print('initializing model...')
+            print("initializing model...")
         model, tokenizer_embeddings = self._get_model_and_tokenizer()

         # get embs
         if verbose:
-            print('calculating embeddings...')
-        embs = self._get_embs_summed(
-            X, model, tokenizer_embeddings, batch_size)
+            print("calculating embeddings...")
+        if cache_embs_dir is not None and os.path.exists(
+            os.path.join(cache_embs_dir, "embs.pkl")
+        ):
+            embs = pkl.load(open(os.path.join(cache_embs_dir, "embs.pkl"), "rb"))
+        else:
+            embs = self._get_embs_summed(X, model, tokenizer_embeddings, batch_size)
+            if cache_embs_dir is not None:
+                os.makedirs(cache_embs_dir, exist_ok=True)
+                pkl.dump(embs, open(os.path.join(cache_embs_dir, "embs.pkl"), "wb"))
         if self.normalize_embs:
             self.normalizer = StandardScaler()
             embs = self.normalizer.fit_transform(embs)
-        if cache_embs_dir is not None:
-            os.makedirs(cache_embs_dir, exist_ok=True)
-            pkl.dump(embs, open(os.path.join(cache_embs_dir, 'embs.pkl'), 'wb'))

         # train linear
         warnings.filterwarnings("ignore", category=ConvergenceWarning)
         if verbose:
-            print('training linear model...')
+            print("training linear model...")
         if isinstance(self, ClassifierMixin):
             self.linear = LogisticRegressionCV()
         elif isinstance(self, RegressorMixin):
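The hunk above changes how fit() obtains embeddings: if cache_embs_dir already holds an embs.pkl it is loaded, otherwise embeddings are computed and, when a cache directory is given, written out. A minimal sketch of that round trip, assuming AugGAMClassifier (the public subclass named in the docstring) is importable from imodelsx and using a throwaway cache directory of your choosing:

from imodelsx import AugGAMClassifier  # assumed top-level export

texts = [
    "a great movie", "really enjoyable", "loved every minute", "wonderful acting", "a delight",
    "a terrible movie", "painfully boring", "hated every minute", "awful acting", "a mess",
]
labels = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]

m = AugGAMClassifier(checkpoint="bert-base-uncased", ngrams=2)
# first call computes embeddings and writes embs_cache/embs.pkl
m.fit(texts, labels, cache_embs_dir="embs_cache")
# a repeat call with the same directory loads the pickle instead of re-embedding
m.fit(texts, labels, cache_embs_dir="embs_cache")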
@@ -139,7 +148,7 @@ def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
         # cache linear coefs
         if cache_linear_coefs:
             if verbose:
-                print('caching linear coefs...')
+                print("caching linear coefs...")
             self.cache_linear_coefs(X, model, tokenizer_embeddings)

         return self
@@ -158,28 +167,44 @@ def _get_embs_summed(self, X, model, tokenizer_embeddings, batch_size):
                 all_ngrams=self.all_ngrams,
                 fit_with_ngram_decomposition=self.fit_with_ngram_decomposition,
                 instructor_prompt=self.instructor_prompt,
-                batch_size=batch_size
+                batch_size=batch_size,
             )
-            embs.append(emb['embs'])
+            embs.append(emb["embs"])
         return np.array(embs).squeeze()  # num_examples x embedding_size

     def _get_model_and_tokenizer(self):
-        if self.checkpoint.startswith('hkunlp/instructor-xl'):
+        if self.checkpoint.startswith("hkunlp/instructor-xl"):
             from InstructorEmbedding import INSTRUCTOR
+
             model = INSTRUCTOR(self.checkpoint).to(device)
             tokenizer_embeddings = None
+        elif 'llama' in self.checkpoint:
+            # path to extracted llama weights
+            LLAMA_DIR = join(os.path.expanduser("~"), "llama")
+            tokenizer_embeddings = transformers.LlamaTokenizer.from_pretrained(
+                join(LLAMA_DIR, self.checkpoint)
+            )
+            model = transformers.LlamaModel.from_pretrained(
+                join(LLAMA_DIR, self.checkpoint),
+                device_map="auto",
+                torch_dtype=torch.float16,
+            )
         else:
-            model = transformers.AutoModel.from_pretrained(
-                self.checkpoint).to(device)
+            model = transformers.AutoModel.from_pretrained(self.checkpoint).to(device)
             tokenizer_embeddings = transformers.AutoTokenizer.from_pretrained(
-                self.checkpoint)
+                self.checkpoint
+            )
         return model, tokenizer_embeddings

-    def cache_linear_coefs(self, X: ArrayLike, model=None,
-                           tokenizer_embeddings=None,
-                           renormalize_embs: bool = False,
-                           batch_size: int = 8,
-                           verbose: bool = True):
+    def cache_linear_coefs(
+        self,
+        X: ArrayLike,
+        model=None,
+        tokenizer_embeddings=None,
+        renormalize_embs: bool = False,
+        batch_size: int = 8,
+        verbose: bool = True,
+    ):
         """Cache linear coefs for ngrams into a dictionary self.coefs_dict_
         If it already exists, only add linear coefs for new ngrams

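The new 'llama' branch above loads the tokenizer and model from a fixed local directory, ~/llama/<checkpoint>, rather than pulling from the Hugging Face hub. A standalone sketch of that load path; the checkpoint name "llama_7b" is purely illustrative, so substitute whatever folder holds your extracted weights:

import os
from os.path import join

import torch
import transformers

LLAMA_DIR = join(os.path.expanduser("~"), "llama")  # same directory convention as the diff
checkpoint = "llama_7b"  # hypothetical folder name under ~/llama

tokenizer = transformers.LlamaTokenizer.from_pretrained(join(LLAMA_DIR, checkpoint))
model = transformers.LlamaModel.from_pretrained(
    join(LLAMA_DIR, checkpoint),
    device_map="auto",          # requires accelerate; shards weights across available devices
    torch_dtype=torch.float16,  # half precision to keep large checkpoints in memory
)

Note that the hunk imports LlamaModel and LlamaTokenizer directly from transformers but then references them through the transformers namespace; either spelling resolves to the same classes.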
@@ -194,18 +219,16 @@ def cache_linear_coefs(self, X: ArrayLike, model=None,
         ngrams_list = self._get_ngrams_list(X)

         # dont recompute ngrams we already know
-        if hasattr(self, 'coefs_dict_'):
+        if hasattr(self, "coefs_dict_"):
             coefs_dict_old = self.coefs_dict_
         else:
             coefs_dict_old = {}
-        ngrams_list = [ngram for ngram in ngrams_list
-                       if not ngram in coefs_dict_old]
+        ngrams_list = [ngram for ngram in ngrams_list if not ngram in coefs_dict_old]
         if len(ngrams_list) == 0 and verbose:
-            print('\tNothing to update!')
+            print("\tNothing to update!")
             return

-        embs = self._get_embs(ngrams_list, model,
-                              tokenizer_embeddings, batch_size)
+        embs = self._get_embs(ngrams_list, model, tokenizer_embeddings, batch_size)
         if renormalize_embs:
             embs = StandardScaler().fit_transform(embs)
         elif self.normalize_embs:
@@ -216,57 +239,47 @@ def cache_linear_coefs(self, X: ArrayLike, model=None,
         linear_coef = embs @ coef_embs
         self.coefs_dict_ = {
             **coefs_dict_old,
-            **{ngrams_list[i]: linear_coef[i]
-               for i in range(len(ngrams_list))}
+            **{ngrams_list[i]: linear_coef[i] for i in range(len(ngrams_list))},
         }
         if verbose:
-            print('\tAfter caching, coefs_dict_ len', len(self.coefs_dict_))
+            print("\tAfter caching, coefs_dict_ len", len(self.coefs_dict_))

     def _get_embs(self, ngrams_list, model, tokenizer_embeddings, batch_size):
-        """Get embeddings for a list of ngrams (not summed!)
-        """
+        """Get embeddings for a list of ngrams (not summed!)"""
         embs = []
-        if self.checkpoint.startswith('hkunlp/instructor-xl'):
+        if self.checkpoint.startswith("hkunlp/instructor-xl"):
             # INSTRUCTION = "Represent the short phrase for sentiment classification: "
             # embs = model.encode([[INSTRUCTION, x_i] for x_i in ngrams_list], batch_size=32)
             embs = []
             batch_size = 32
             for i in tqdm(range(0, len(ngrams_list), batch_size)):
                 # ngram = ngrams_list[i]
                 # embs.append(model.encode([[INSTRUCTION, ngram]])[0])
-                ngram_batch = ngrams_list[i: i + batch_size]
+                ngram_batch = ngrams_list[i : i + batch_size]
                 embs_batch = model.encode(
-                    [[self.instructor_prompt, ngram] for ngram in ngram_batch])
+                    [[self.instructor_prompt, ngram] for ngram in ngram_batch]
+                )
                 embs.append(embs_batch)
             embs = np.vstack(embs).squeeze()
-        else:
-            for i in tqdm(range(len(ngrams_list))):
-                tokens = tokenizer_embeddings(
-                    [ngrams_list[i]], padding=True, truncation=True, return_tensors="pt")
-
-                tokens = Dataset.from_dict(tokens).with_format("torch")
-
-                embeddings = []
-                for batch in DataLoader(tokens, batch_size=batch_size, shuffle=False):
-                    batch = {k: v.to(model.device) for k, v in batch.items()}
-
-                    with torch.no_grad():
-                        output = model(**batch)
-                    torch.cuda.empty_cache()
-
-                    emb = output[self.layer].cpu().detach().numpy()
-
-                    # emb = np.array(emb, dtype="object")
-                    if len(emb.shape) == 3:  # includes seq_len
-                        emb = emb.mean(axis=1)
-                    embeddings.append(emb)

-                embeddings = np.concatenate(embeddings)
-
-                embs.append(embeddings)
+        else:
+            embs = []
+            for x in tqdm(ngrams_list):
+                emb = imodelsx.auggam.embed.embed_and_sum_function(
+                    x,
+                    model=model,
+                    ngrams=None,
+                    tokenizer_embeddings=tokenizer_embeddings,
+                    tokenizer_ngrams=self.tokenizer_ngrams,
+                    checkpoint=self.checkpoint,
+                    layer=self.layer,
+                    # only return a single embedding
+                    fit_with_ngram_decomposition=False,
+                    sum_embeddings=False,
+                )
+                embs.append(emb["embs"])
+            embs = np.array(embs).squeeze()  # num_examples x embedding_size

-        embs = np.concatenate(embs)
-        embs = embs.squeeze()
         return embs

     """
@@ -288,15 +301,15 @@ def _get_ngrams_list(self, X):
                 ngrams=self.ngrams,
                 tokenizer_ngrams=self.tokenizer_ngrams,
                 all_ngrams=self.all_ngrams,
-                min_frequency=self.min_frequency
+                min_frequency=self.min_frequency,
             )
             all_ngrams |= set(seqs)
         return sorted(list(all_ngrams))

     def predict(self, X, warn=True):
-        '''For regression returns continuous output.
+        """For regression returns continuous output.
         For classification, returns discrete output.
-        '''
+        """
         check_is_fitted(self)
         preds = self._predict_cached(X, warn=warn)
         if isinstance(self, RegressorMixin):
@@ -309,21 +322,18 @@ def predict(self, X, warn=True):

     def predict_proba(self, X, warn=True):
         if not isinstance(self, ClassifierMixin):
-            raise Exception(
-                "predict_proba only available for EmbGAMClassifier")
+            raise Exception("predict_proba only available for EmbGAMClassifier")
         check_is_fitted(self)
         preds = self._predict_cached(X, warn=warn)
         if preds.ndim > 1:  # multiclass classification
             logits = preds
         else:
-            logits = np.vstack(
-                (1 - preds, preds)).transpose()
+            logits = np.vstack((1 - preds, preds)).transpose()
         return softmax(logits, axis=1)

     def _predict_cached(self, X, warn):
-        """Predict only the cached coefs in self.coefs_dict_
-        """
-        assert hasattr(self, 'coefs_dict_'), 'coefs are not cached!'
+        """Predict only the cached coefs in self.coefs_dict_"""
+        assert hasattr(self, "coefs_dict_"), "coefs are not cached!"
         preds = []
         n_unseen_ngrams = 0
         n_classes = len(self.classes_)
@@ -346,9 +356,10 @@ def _predict_cached(self, X, warn):
             preds.append(pred)
         if n_unseen_ngrams > 0 and warn:
             warnings.warn(
-                f'Saw an unseen ungram {n_unseen_ngrams} times. \
+                f"Saw an unseen ungram {n_unseen_ngrams} times. \
                 For better performance, call cache_linear_coefs on the test dataset \
-                before calling predict.')
+                before calling predict."
+            )
         return np.array(preds)
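Taken together with the llama branch in _get_model_and_tokenizer, a hedged end-to-end sketch of how the updated estimator might be driven, including the cache_linear_coefs call that the unseen-ngram warning above recommends. The checkpoint name and the data are illustrative placeholders, not part of this commit:

from imodelsx import AugGAMClassifier  # assumed top-level export

train_texts = [
    "great plot", "sharp dialogue", "beautiful score", "wonderful acting", "a delight",
    "dull pacing", "weak script", "flat acting", "painfully boring", "a mess",
]
train_labels = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
test_texts = ["a beautiful, sharp script", "dull and flat"]

m = AugGAMClassifier(
    checkpoint="llama_7b",  # hypothetical name; resolved to ~/llama/llama_7b by the new branch
    layer="last_hidden_state",
    ngrams=2,
)
m.fit(train_texts, train_labels, cache_embs_dir="embs_cache")

# cache coefficients for test-set ngrams so predict() does not warn about unseen ngrams
m.cache_linear_coefs(test_texts)
preds = m.predict(test_texts)
probs = m.predict_proba(test_texts)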
0 commit comments