 import os.path
 import warnings
 import pickle as pkl
+from os.path import join
 import torch
-from torch.utils.data import DataLoader
-from datasets import Dataset
+from transformers import LlamaModel, LlamaTokenizer
 from sklearn.exceptions import ConvergenceWarning
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+device = "cuda" if torch.cuda.is_available() else "cpu"


 class AugGAM(BaseEstimator):
     def __init__(
         self,
-        checkpoint: str = 'bert-base-uncased',
-        layer: str = 'last_hidden_state',
+        checkpoint: str = "bert-base-uncased",
+        layer: str = "last_hidden_state",
         ngrams: int = 2,
         all_ngrams: bool = False,
         min_frequency: int = 1,
@@ -43,7 +44,7 @@ def __init__(
         fit_with_ngram_decomposition=True,
         instructor_prompt=None,
     ):
-        '''AugGAM-GAM Class - use either AugGAMClassifier or AugGAMRegressor rather than initializing this class directly.
+        """AugGAM-GAM Class - use either AugGAMClassifier or AugGAMRegressor rather than initializing this class directly.

         Parameters
         ----------
@@ -69,7 +70,7 @@ def __init__(
             Usually, setting this to False will considerably impede performance
         instructor_prompt
             if not None, use instructor-xl with this prompt
-        '''
+        """
         self.checkpoint = checkpoint
         self.ngrams = ngrams
         if tokenizer_ngrams == None:
@@ -84,12 +85,16 @@ def __init__(
         self.fit_with_ngram_decomposition = fit_with_ngram_decomposition
         self.instructor_prompt = instructor_prompt

-    def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
-            cache_linear_coefs: bool = True,
-            cache_embs_dir: str = None,
-            batch_size: int = 8
-            ):
-        '''Extract embeddings then fit linear model
+    def fit(
+        self,
+        X: ArrayLike,
+        y: ArrayLike,
+        verbose=True,
+        cache_linear_coefs: bool = True,
+        cache_embs_dir: str = None,
+        batch_size: int = 8,
+    ):
+        """Extract embeddings then fit linear model

         Parameters
         ----------
@@ -101,7 +106,7 @@ def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
             if not None, directory to save embeddings into
         batch_size, optional
             if not None, batch size to pass while calculating embeddings
-        '''
+        """

         # metadata
         if isinstance(self, ClassifierMixin):
@@ -111,25 +116,29 @@ def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,

         # set up model
         if verbose:
-            print('initializing model...')
+            print("initializing model...")
         model, tokenizer_embeddings = self._get_model_and_tokenizer()

         # get embs
         if verbose:
-            print('calculating embeddings...')
-        embs = self._get_embs_summed(
-            X, model, tokenizer_embeddings, batch_size)
+            print("calculating embeddings...")
+        if cache_embs_dir is not None and os.path.exists(
+            os.path.join(cache_embs_dir, "embs.pkl")
+        ):
+            embs = pkl.load(open(os.path.join(cache_embs_dir, "embs.pkl"), "rb"))
+        else:
+            embs = self._get_embs_summed(X, model, tokenizer_embeddings, batch_size)
+            if cache_embs_dir is not None:
+                os.makedirs(cache_embs_dir, exist_ok=True)
+                pkl.dump(embs, open(os.path.join(cache_embs_dir, "embs.pkl"), "wb"))
         if self.normalize_embs:
             self.normalizer = StandardScaler()
             embs = self.normalizer.fit_transform(embs)
-        if cache_embs_dir is not None:
-            os.makedirs(cache_embs_dir, exist_ok=True)
-            pkl.dump(embs, open(os.path.join(cache_embs_dir, 'embs.pkl'), 'wb'))

         # train linear
         warnings.filterwarnings("ignore", category=ConvergenceWarning)
         if verbose:
-            print('training linear model...')
+            print("training linear model...")
         if isinstance(self, ClassifierMixin):
             self.linear = LogisticRegressionCV()
         elif isinstance(self, RegressorMixin):
@@ -139,7 +148,7 @@ def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
         # cache linear coefs
         if cache_linear_coefs:
             if verbose:
-                print('caching linear coefs...')
+                print("caching linear coefs...")
             self.cache_linear_coefs(X, model, tokenizer_embeddings)

         return self
@@ -158,28 +167,44 @@ def _get_embs_summed(self, X, model, tokenizer_embeddings, batch_size):
                 all_ngrams=self.all_ngrams,
                 fit_with_ngram_decomposition=self.fit_with_ngram_decomposition,
                 instructor_prompt=self.instructor_prompt,
-                batch_size=batch_size
+                batch_size=batch_size,
             )
-            embs.append(emb['embs'])
+            embs.append(emb["embs"])
         return np.array(embs).squeeze()  # num_examples x embedding_size

     def _get_model_and_tokenizer(self):
-        if self.checkpoint.startswith('hkunlp/instructor-xl'):
+        if self.checkpoint.startswith("hkunlp/instructor-xl"):
             from InstructorEmbedding import INSTRUCTOR
+
             model = INSTRUCTOR(self.checkpoint).to(device)
             tokenizer_embeddings = None
+        elif 'llama' in self.checkpoint:
+            # path to extracted llama weights
+            LLAMA_DIR = join(os.path.expanduser("~"), "llama")
+            tokenizer_embeddings = transformers.LlamaTokenizer.from_pretrained(
+                join(LLAMA_DIR, self.checkpoint)
+            )
+            model = transformers.LlamaModel.from_pretrained(
+                join(LLAMA_DIR, self.checkpoint),
+                device_map="auto",
+                torch_dtype=torch.float16,
+            )
         else:
-            model = transformers.AutoModel.from_pretrained(
-                self.checkpoint).to(device)
+            model = transformers.AutoModel.from_pretrained(self.checkpoint).to(device)
             tokenizer_embeddings = transformers.AutoTokenizer.from_pretrained(
-                self.checkpoint)
+                self.checkpoint
+            )
         return model, tokenizer_embeddings

-    def cache_linear_coefs(self, X: ArrayLike, model=None,
-                           tokenizer_embeddings=None,
-                           renormalize_embs: bool = False,
-                           batch_size: int = 8,
-                           verbose: bool = True):
+    def cache_linear_coefs(
+        self,
+        X: ArrayLike,
+        model=None,
+        tokenizer_embeddings=None,
+        renormalize_embs: bool = False,
+        batch_size: int = 8,
+        verbose: bool = True,
+    ):
         """Cache linear coefs for ngrams into a dictionary self.coefs_dict_
         If it already exists, only add linear coefs for new ngrams

@@ -194,18 +219,16 @@ def cache_linear_coefs(self, X: ArrayLike, model=None,
         ngrams_list = self._get_ngrams_list(X)

         # dont recompute ngrams we already know
-        if hasattr(self, 'coefs_dict_'):
+        if hasattr(self, "coefs_dict_"):
             coefs_dict_old = self.coefs_dict_
         else:
             coefs_dict_old = {}
-        ngrams_list = [ngram for ngram in ngrams_list
-                       if not ngram in coefs_dict_old]
+        ngrams_list = [ngram for ngram in ngrams_list if not ngram in coefs_dict_old]
         if len(ngrams_list) == 0 and verbose:
-            print('\tNothing to update!')
+            print("\tNothing to update!")
             return

-        embs = self._get_embs(ngrams_list, model,
-                              tokenizer_embeddings, batch_size)
+        embs = self._get_embs(ngrams_list, model, tokenizer_embeddings, batch_size)
         if renormalize_embs:
             embs = StandardScaler().fit_transform(embs)
         elif self.normalize_embs:
@@ -216,57 +239,47 @@ def cache_linear_coefs(self, X: ArrayLike, model=None,
         linear_coef = embs @ coef_embs
         self.coefs_dict_ = {
             **coefs_dict_old,
-            **{ngrams_list[i]: linear_coef[i]
-               for i in range(len(ngrams_list))}
+            **{ngrams_list[i]: linear_coef[i] for i in range(len(ngrams_list))},
         }
         if verbose:
-            print('\tAfter caching, coefs_dict_ len', len(self.coefs_dict_))
+            print("\tAfter caching, coefs_dict_ len", len(self.coefs_dict_))

     def _get_embs(self, ngrams_list, model, tokenizer_embeddings, batch_size):
-        """Get embeddings for a list of ngrams (not summed!)
-        """
+        """Get embeddings for a list of ngrams (not summed!)"""
         embs = []
-        if self.checkpoint.startswith('hkunlp/instructor-xl'):
+        if self.checkpoint.startswith("hkunlp/instructor-xl"):
             # INSTRUCTION = "Represent the short phrase for sentiment classification: "
             # embs = model.encode([[INSTRUCTION, x_i] for x_i in ngrams_list], batch_size=32)
             embs = []
             batch_size = 32
             for i in tqdm(range(0, len(ngrams_list), batch_size)):
                 # ngram = ngrams_list[i]
                 # embs.append(model.encode([[INSTRUCTION, ngram]])[0])
-                ngram_batch = ngrams_list[i: i + batch_size]
+                ngram_batch = ngrams_list[i : i + batch_size]
                 embs_batch = model.encode(
-                    [[self.instructor_prompt, ngram] for ngram in ngram_batch])
+                    [[self.instructor_prompt, ngram] for ngram in ngram_batch]
+                )
                 embs.append(embs_batch)
             embs = np.vstack(embs).squeeze()
-        else:
-            for i in tqdm(range(len(ngrams_list))):
-                tokens = tokenizer_embeddings(
-                    [ngrams_list[i]], padding=True, truncation=True, return_tensors="pt")
-
-                tokens = Dataset.from_dict(tokens).with_format("torch")
-
-                embeddings = []
-                for batch in DataLoader(tokens, batch_size=batch_size, shuffle=False):
-                    batch = {k: v.to(model.device) for k, v in batch.items()}
-
-                    with torch.no_grad():
-                        output = model(**batch)
-                    torch.cuda.empty_cache()
-
-                    emb = output[self.layer].cpu().detach().numpy()
-
-                    # emb = np.array(emb, dtype="object")
-                    if len(emb.shape) == 3:  # includes seq_len
-                        emb = emb.mean(axis=1)
-                    embeddings.append(emb)

-                embeddings = np.concatenate(embeddings)
-
-                embs.append(embeddings)
+        else:
+            embs = []
+            for x in tqdm(ngrams_list):
+                emb = imodelsx.auggam.embed.embed_and_sum_function(
+                    x,
+                    model=model,
+                    ngrams=None,
+                    tokenizer_embeddings=tokenizer_embeddings,
+                    tokenizer_ngrams=self.tokenizer_ngrams,
+                    checkpoint=self.checkpoint,
+                    layer=self.layer,
+                    # only return a single embedding
+                    fit_with_ngram_decomposition=False,
+                    sum_embeddings=False,
+                )
+                embs.append(emb["embs"])
+            embs = np.array(embs).squeeze()  # num_examples x embedding_size

-        embs = np.concatenate(embs)
-        embs = embs.squeeze()
         return embs

         """
@@ -288,15 +301,15 @@ def _get_ngrams_list(self, X):
                 ngrams=self.ngrams,
                 tokenizer_ngrams=self.tokenizer_ngrams,
                 all_ngrams=self.all_ngrams,
-                min_frequency=self.min_frequency
+                min_frequency=self.min_frequency,
             )
             all_ngrams |= set(seqs)
         return sorted(list(all_ngrams))

     def predict(self, X, warn=True):
-        '''For regression returns continuous output.
+        """For regression returns continuous output.
         For classification, returns discrete output.
-        '''
+        """
         check_is_fitted(self)
         preds = self._predict_cached(X, warn=warn)
         if isinstance(self, RegressorMixin):
@@ -309,21 +322,18 @@ def predict(self, X, warn=True):

     def predict_proba(self, X, warn=True):
         if not isinstance(self, ClassifierMixin):
-            raise Exception(
-                "predict_proba only available for EmbGAMClassifier")
+            raise Exception("predict_proba only available for EmbGAMClassifier")
         check_is_fitted(self)
         preds = self._predict_cached(X, warn=warn)
         if preds.ndim > 1:  # multiclass classification
             logits = preds
         else:
-            logits = np.vstack(
-                (1 - preds, preds)).transpose()
+            logits = np.vstack((1 - preds, preds)).transpose()
         return softmax(logits, axis=1)

     def _predict_cached(self, X, warn):
-        """Predict only the cached coefs in self.coefs_dict_
-        """
-        assert hasattr(self, 'coefs_dict_'), 'coefs are not cached!'
+        """Predict only the cached coefs in self.coefs_dict_"""
+        assert hasattr(self, "coefs_dict_"), "coefs are not cached!"
         preds = []
         n_unseen_ngrams = 0
         n_classes = len(self.classes_)
@@ -346,9 +356,10 @@ def _predict_cached(self, X, warn):
             preds.append(pred)
         if n_unseen_ngrams > 0 and warn:
             warnings.warn(
-                f'Saw an unseen ungram {n_unseen_ngrams} times. \
+                f"Saw an unseen ungram {n_unseen_ngrams} times. \
 For better performance, call cache_linear_coefs on the test dataset \
-                    before calling predict.')
+                    before calling predict."
+            )
         return np.array(preds)

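For context, here is a minimal usage sketch of the API this commit touches. It is not part of the diff above: it assumes AugGAMClassifier is importable from the imodelsx package (as the class docstring suggests), and the toy data and cache path are placeholders.

# minimal sketch, not part of the commit
from imodelsx import AugGAMClassifier  # assumed public classifier wrapper around AugGAM

# placeholder data; LogisticRegressionCV needs several examples per class
X_train = ["loved it", "great movie", "really fun", "wonderful acting", "a delight",
           "hated it", "terrible movie", "really dull", "awful acting", "a bore"]
y_train = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
X_test = ["an awful film", "a wonderful film"]

m = AugGAMClassifier(
    checkpoint="bert-base-uncased",  # per this commit, a checkpoint containing 'llama' is instead loaded from ~/llama
    ngrams=2,
)

# extracts ngram embeddings (reusing cache/embs/embs.pkl on later fits, per this commit),
# then trains LogisticRegressionCV and caches linear coefs for the training ngrams
m.fit(X_train, y_train, cache_embs_dir="cache/embs")

# cache linear coefficients for the test ngrams so predict() does not warn about unseen ngrams
m.cache_linear_coefs(X_test)

print(m.predict(X_test))
print(m.predict_proba(X_test))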