@@ -29,9 +29,14 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
 <pre><code class="python">from transformers import BertModel, DistilBertModel
 from transformers import AutoModelForCausalLM
 from os.path import join as oj
+from datasets import Dataset
+from tqdm import tqdm
 import torch
+import numpy as np
+from torch.utils.data import DataLoader
 import imodelsx.util
 
+
 def get_model(checkpoint):
     if 'distilbert' in checkpoint.lower():
         model = DistilBertModel.from_pretrained(checkpoint)
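
The new imports exist to support batched embedding. As a quick illustration (the toy token dict below is mine, not from the diff), a tokenized batch can be wrapped in a `datasets.Dataset`, put in torch format, and streamed through a `DataLoader` in fixed-size chunks:

```python
# Illustrative sketch only: a dict of equal-length token lists (shaped like tokenizer
# output) wrapped in a datasets.Dataset and iterated in fixed-size batches.
from datasets import Dataset
from torch.utils.data import DataLoader

tokens = {
    'input_ids': [[101, 7592, 102, 0], [101, 2088, 999, 102]],
    'attention_mask': [[1, 1, 1, 0], [1, 1, 1, 1]],
}
ds = Dataset.from_dict(tokens).with_format("torch")  # indexing now yields torch tensors
for batch in DataLoader(ds, batch_size=1, shuffle=False):
    # each batch is a dict of stacked tensors, ready to be passed as model(**batch)
    print({k: v.shape for k, v in batch.items()})
```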
@@ -79,7 +84,8 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
     checkpoint: str,
     dataset_key_text: str = None,
     layer: str = 'last_hidden_state',
-    padding: bool = True,
+    padding: str = "max_length",
+    batch_size: int = 8,
     parsing: str = '',
     nlp_chunks=None,
     all_ngrams: bool = False,
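
The `padding` parameter now takes a tokenizer padding strategy string rather than a bool. With `"max_length"` every sequence is padded to the model maximum, so all tokenized rows have the same length and can be stacked into batches; `padding=True` (the old default) only pads to the longest sequence in each call. A small sketch of the difference at the tokenizer level (the checkpoint name is illustrative):

```python
# Sketch of the padding change, using an illustrative checkpoint.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('distilbert-base-uncased')
seqs = ['short', 'a slightly longer sequence']

longest = tok(seqs, padding=True, truncation=True, return_tensors='pt')
fixed = tok(seqs, padding='max_length', truncation=True, return_tensors='pt')

print(longest['input_ids'].shape)  # (2, length of the longest sequence in this call)
print(fixed['input_ids'].shape)    # (2, tok.model_max_length), e.g. (2, 512)
```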
@@ -134,27 +140,57 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
             tokenizer_embeddings.pad_token = tokenizer_embeddings.eos_token
         tokens = tokenizer_embeddings(seqs, padding=padding,
                                       truncation=True, return_tensors="pt")
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-        if layer == 'pooler_output':
-            embs = output['pooler_output'].cpu().detach().numpy()
-        elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
-            embs = output['last_hidden_state'].cpu().detach().numpy()
-            embs = embs.mean(axis=1)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            if layer == 'pooler_output':
+                emb = output['pooler_output'].cpu().detach().numpy()
+            elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
+                emb = output['last_hidden_state'].cpu().detach().numpy()
+                emb = emb.mean(axis=1)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif 'gpt' in checkpoint.lower():
         tokens = preprocess_gpt_token_batch(seqs, tokenizer_embeddings)
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-
-        # tuple of (layer x (batch_size, seq_len, hidden_size))
-        h = output['hidden_states']
-        # (batch_size, seq_len, hidden_size)
-        embs = h[0].cpu().detach().numpy()
-        embs = embs.mean(axis=1) # (batch_size, hidden_size)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            # tuple of (layer x (batch_size, seq_len, hidden_size))
+            h = output['hidden_states']
+            # (batch_size, seq_len, hidden_size)
+            emb = h[0].cpu().detach().numpy()
+            emb = emb.mean(axis=1) # (batch_size, hidden_size)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif checkpoint.startswith('hkunlp/instructor'):
         if instructor_prompt is None:
             instructor_prompt = "Represent the short phrase for sentiment classification: "
-        embs = model.encode([[instructor_prompt, x_i] for x_i in seqs], batch_size=32)
+        embs = model.encode([[instructor_prompt, x_i]
+                             for x_i in seqs], batch_size=batch_size)
 
     # sum over the embeddings
     embs = embs.sum(axis=0).reshape(1, -1)
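
The core of this change is the batched, no-grad embedding loop. The following standalone sketch mirrors that pattern outside the library (the checkpoint name and input sequences are illustrative; it is not the library function itself):

```python
# Standalone sketch of the batching pattern above: tokenize, wrap in a Dataset,
# embed batch-by-batch under torch.no_grad(), then concatenate per-batch embeddings.
import numpy as np
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

checkpoint = 'distilbert-base-uncased'  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint).eval()

seqs = ['a great movie', 'a terrible movie', 'an average movie']
tokens = tokenizer(seqs, padding='max_length', truncation=True, return_tensors='pt')

ds = Dataset.from_dict(tokens).with_format('torch')
embs = []
for batch in DataLoader(ds, batch_size=2, shuffle=False):
    batch = {k: v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():
        output = model(**batch)
    # mean-pool the last hidden state -> (batch_size, hidden_size)
    embs.append(output['last_hidden_state'].mean(dim=1).cpu().numpy())
embs = np.concatenate(embs)  # shape: (len(seqs), hidden_size)
```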
@@ -172,7 +208,7 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
 <h2 class="section-title" id="header-functions">Functions</h2>
 <dl>
 <dt id="imodelsx.auggam.embed.embed_and_sum_function"><code class="name flex">
-<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: bool = True, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True, instructor_prompt: str = None)</span>
+<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: str = 'max_length', batch_size: int = 8, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True, instructor_prompt: str = None)</span>
 </code></dt>
 <dd>
 <div class="desc"><p>Get summed embeddings for a single example</p>
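
A hedged usage sketch of the updated signature follows. The surrounding setup (checkpoint, tokenizers, passing a raw string as `example`) is illustrative and not specified by this diff; only the `padding` and `batch_size` keywords reflect the change itself.

```python
# Illustrative call showing where the new keyword arguments plug in.
# Everything other than padding/batch_size is assumed setup, not taken from the diff.
from spacy.lang.en import English
from transformers import AutoModel, AutoTokenizer
from imodelsx.auggam.embed import embed_and_sum_function

checkpoint = 'distilbert-base-uncased'            # illustrative checkpoint
model = AutoModel.from_pretrained(checkpoint)
tokenizer_embeddings = AutoTokenizer.from_pretrained(checkpoint)
tokenizer_ngrams = English().tokenizer            # assumed word-level tokenizer for ngrams

out = embed_and_sum_function(
    'this movie was great',                       # assumed: raw text when dataset_key_text is None
    model=model,
    ngrams=2,
    tokenizer_embeddings=tokenizer_embeddings,
    tokenizer_ngrams=tokenizer_ngrams,
    checkpoint=checkpoint,
    padding='max_length',                         # now a tokenizer padding strategy string
    batch_size=8,                                 # new: ngram sequences are embedded in chunks of 8
)
```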
@@ -206,7 +242,8 @@ <h2 id="params">Params</h2>
     checkpoint: str,
     dataset_key_text: str = None,
     layer: str = 'last_hidden_state',
-    padding: bool = True,
+    padding: str = "max_length",
+    batch_size: int = 8,
     parsing: str = '',
     nlp_chunks=None,
     all_ngrams: bool = False,
@@ -261,27 +298,57 @@ <h2 id="params">Params</h2>
             tokenizer_embeddings.pad_token = tokenizer_embeddings.eos_token
         tokens = tokenizer_embeddings(seqs, padding=padding,
                                       truncation=True, return_tensors="pt")
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-        if layer == 'pooler_output':
-            embs = output['pooler_output'].cpu().detach().numpy()
-        elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
-            embs = output['last_hidden_state'].cpu().detach().numpy()
-            embs = embs.mean(axis=1)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            if layer == 'pooler_output':
+                emb = output['pooler_output'].cpu().detach().numpy()
+            elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
+                emb = output['last_hidden_state'].cpu().detach().numpy()
+                emb = emb.mean(axis=1)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif 'gpt' in checkpoint.lower():
         tokens = preprocess_gpt_token_batch(seqs, tokenizer_embeddings)
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-
-        # tuple of (layer x (batch_size, seq_len, hidden_size))
-        h = output['hidden_states']
-        # (batch_size, seq_len, hidden_size)
-        embs = h[0].cpu().detach().numpy()
-        embs = embs.mean(axis=1) # (batch_size, hidden_size)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            # tuple of (layer x (batch_size, seq_len, hidden_size))
+            h = output['hidden_states']
+            # (batch_size, seq_len, hidden_size)
+            emb = h[0].cpu().detach().numpy()
+            emb = emb.mean(axis=1) # (batch_size, hidden_size)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif checkpoint.startswith('hkunlp/instructor'):
         if instructor_prompt is None:
             instructor_prompt = "Represent the short phrase for sentiment classification: "
-        embs = model.encode([[instructor_prompt, x_i] for x_i in seqs], batch_size=32)
+        embs = model.encode([[instructor_prompt, x_i]
+                             for x_i in seqs], batch_size=batch_size)
 
     # sum over the embeddings
     embs = embs.sum(axis=0).reshape(1, -1)