
Commit 4446be4

Minor bug fixes and code cleanup
1 parent 32983bb commit 4446be4

File tree: 3 files changed, +56 −49 lines

model_generator.py

Lines changed: 19 additions & 19 deletions
@@ -1,4 +1,4 @@
-from transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config
+from transformers import GPT2LMHeadModel, GPT2Config
 
 import torch.utils.data.dataset
 import utils_tokenizer
@@ -16,8 +16,8 @@ def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tok
         elif tokenizer_type == "bpecap":
             self.tokenizer = utils_tokenizer.BPETokenizer(bpe_model)
             config = GPT2Config.from_dict({"finetuning_task": None, "initializer_range": 0.02,
-                "layer_norm_epsilon": 1e-05, "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_layer": 12, "n_positions": 1024, "num_labels": 1,
-                "resid_pdrop": 0.1, "use_bfloat16": False, "vocab_size": self.tokenizer.vocab_size})
+                "layer_norm_epsilon": 1e-05, "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_layer": 12, "n_positions": 1024, "num_labels": 1,
+                "resid_pdrop": 0.1, "use_bfloat16": False, "vocab_size": self.tokenizer.vocab_size})
         else:
             print("Tokenizer unrecognized. Should be gpt2 or bpecap.")
             exit()
@@ -36,7 +36,7 @@ def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tok
         self.mode = "train"
 
     def reload(self, from_file):
-        print(self.model.load_state_dict(torch.load(from_file)))
+        print(self.model.load_state_dict(torch.load(from_file), strict=False))
 
     def save(self, to_file):
         torch.save(self.model.state_dict(), to_file)
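Note on the reload() change: passing strict=False makes load_state_dict accept checkpoints whose keys no longer match the model exactly (for instance after the switch to the newer transformers imports above), reporting mismatches instead of raising. A minimal sketch, not from this repo, using a plain nn.Linear to show the behavior:

import torch
import torch.nn as nn

# Hypothetical model and checkpoint, only to illustrate strict=False.
model = nn.Linear(4, 2)
checkpoint = model.state_dict()
checkpoint["extra.weight"] = torch.zeros(1)      # simulate a stale or renamed key

# model.load_state_dict(checkpoint)              # strict=True would raise a RuntimeError here
result = model.load_state_dict(checkpoint, strict=False)
print(result)                                    # lists unexpected_keys=['extra.weight'] instead of raising

Printing the return value, as the repo does, surfaces the missing and unexpected keys in the log so a partial load does not go unnoticed.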
@@ -132,22 +132,22 @@ def decode_batch(self, bodies, special_append=None, max_output_length=100, sampl
         elif return_scores:
             return outputs, scores.tolist()
         else:
-            return outputs, end_indices
+            return outputs
 
     def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=False):
         if self.mode != 'eval':
             print("BEWARE. Model is not in eval mode.")
-            self.eval() ## << Surely you are not training with beam decode?
+            self.eval()  # << Surely you are not training with beam decode?
 
         batch_size = len(bodies)
         N = batch_size * beam_size
         inputs = self.preprocess_input(bodies)
         next_words = torch.LongTensor([self.tokenizer.start_id] * N).to(self.device).unsqueeze(1)
         build_up = None
         scores = torch.zeros((N)).to(self.device)
-
+
         one_every_k = torch.FloatTensor([1] + [0] * (beam_size-1)).repeat(batch_size*beam_size).to(self.device)
-
+
         # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
         _, input_past = self.model(input_ids=inputs, past_key_values=None)
         input_past = [torch.repeat_interleave(p, repeats=beam_size, dim=1) for p in input_past]
@@ -157,23 +157,23 @@ def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=F
             logits, past = self.model(input_ids=next_words, past_key_values=past)
             probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
             logprobs = torch.nn.functional.log_softmax(logits, dim=2)
-
+
             if sample:
                 all_selects = torch.multinomial(probs, beam_size).unsqueeze(1)
             else:
                 _, all_selects = torch.topk(logprobs, k=beam_size, dim=2)
-
+
             if build_up is not None:
                 not_finished = (1-torch.any(build_up==self.tokenizer.end_id, dim=1).float()).to(self.device)
             else:
-                not_finished = torch.ones_like(scores, dtype=torch.float, device=self.device)
-
+                not_finished = torch.ones_like(scores, dtype=torch.float, device=self.device)
+
             expanded_not_finished = torch.repeat_interleave(not_finished, repeats=beam_size)
-
+
             expanded_score = torch.repeat_interleave(scores, repeats=beam_size) # This should be batch_size * beam_size²
             added_score = logprobs[torch.repeat_interleave(torch.arange(N), repeats=beam_size), 0, all_selects.view(-1)]
             expanded_score += (expanded_not_finished*added_score)
-
+
             # We don't want you to select from finished beams
             expanded_score -= (1-expanded_not_finished)*(1-one_every_k)*1000.0
 
@@ -182,11 +182,11 @@ def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=F
             if build_up is None:
                 choices = torch.arange(beam_size, device=self.device).repeat(batch_size)
                 batched_choices = choices.view(batch_size, beam_size)
-
+
             else:
                 _, batched_choices = torch.topk(batched_scores, k=beam_size, dim=1) # Going from k² choices per element to k choices.
-
-            batched_tracks = batched_choices / beam_size
+
+            batched_tracks = (batched_choices / beam_size).long()
             tracks = beam_size*torch.repeat_interleave(torch.arange(batch_size), repeats=beam_size).to(self.device) + batched_tracks.view(-1)
 
             selected_scores = batched_scores[torch.repeat_interleave(torch.arange(batch_size), repeats=beam_size), batched_choices.view(-1)]
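The `(batched_choices / beam_size).long()` change is the substantive fix in this hunk: in recent PyTorch releases, `/` on an integer tensor performs true division and returns a float tensor, which can no longer be used to index the cached tensors. A minimal sketch of the failure and the fix (illustrative values only, assuming a PyTorch version with true division):

import torch

beam_size = 3
batched_choices = torch.tensor([[0, 4, 7], [2, 3, 8]])  # top-k picks among beam_size**2 candidates

float_tracks = batched_choices / beam_size              # true division: float tensor
print(float_tracks.dtype)                               # torch.float32

past = torch.randn(12, 2 * beam_size, 64)               # stand-in for a cached attention tensor
# past[:, float_tracks.view(-1), :]                     # IndexError: indices must be long, byte or bool

batched_tracks = (batched_choices / beam_size).long()   # parent beam ("track") of each choice
print(batched_tracks)                                   # tensor([[0, 1, 2], [0, 1, 2]])
reindexed = past[:, batched_tracks.view(-1), :]         # indexing works again

Since the choice indices are non-negative, truncating with .long() is equivalent to floor division (`//`) here.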
@@ -200,7 +200,7 @@ def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=F
             if build_up is not None:
                 build_up = build_up[tracks, :]
                 past = [p[:, tracks, :] for p in past]
-
+
             # Update the latest scores, and the current_build
             if build_up is None:
                 build_up = next_words
@@ -228,7 +228,7 @@ def decode(self, bodies, max_output_length=100, max_batch_size=8, beam_size=1, r
         if progress:
            iterator = tqdm.tqdm(iterator)
         for i in iterator:
-            batch_bodies = bodies[i:min(N,i+max_batch_size)]
+            batch_bodies = bodies[i:min(N, i+max_batch_size)]
            with torch.no_grad():
                 if beam_size > 1:
                     batch_outputs = self.decode_beam_batch(batch_bodies, beam_size=beam_size, max_output_length=max_output_length, sample=sample)

train_summary_loop.py

Lines changed: 16 additions & 15 deletions
@@ -1,11 +1,11 @@
 from torch.utils.data import DataLoader, RandomSampler
-import torch, os, sys, time, argparse, numpy as np
 from utils_dataset import SQLDataset, HDF5Dataset
+import torch, os, time, argparse, numpy as np
 from transformers.optimization import AdamW
 from model_generator import GeneTransformer
-from datetime import datetime, timedelta
-from utils_logplot import LogPlot
 import utils_misc, utils_tokenizer
+from utils_logplot import LogPlot
+from datetime import datetime
 
 from model_coverage import KeywordCoverage
 from model_guardrails import PatternPenalty, LengthPenalty, RepeatPenalty
@@ -17,7 +17,6 @@
 parser.add_argument("--experiment", type=str, required=True, help="Experiment name. Will be used to save a model file and a log file.")
 parser.add_argument("--dataset_file", type=str, required=True, help="Which dataset file to use. Can be full path or the root folder will be attached.")
 
-parser.add_argument("--root_folder", type=str, default="/home/"+user+"/")
 parser.add_argument("--train_batch_size", type=int, default=5, help="Training batch size.")
 parser.add_argument("--n_epochs", type=int, default=3, help="Number of epochs to run over the data.")
 parser.add_argument("--optim_every", type=int, default=4, help="Optimize every x backprops. A multiplier to the true batch size.")
@@ -34,8 +33,8 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = ""+str(freer_gpu)
 args.experiment += "_"+freer_gpu
 
-models_folder = "/home/ubuntu/models/"
-log_folder = "/home/ubuntu/logs/"
+models_folder = "/home/phillab/models/"
+log_folder = "/home/phillab/logs/"
 
 summarizer_model_start = os.path.join(models_folder, "gpt2_copier23.bin")
 
@@ -65,6 +64,7 @@ def collate_func(inps):
     else:
         return [inp[0].decode() for inp in inps]
 
+
 param_optimizer = list(summarizer.model.named_parameters())
 no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
 optimizer_grouped_parameters = [
@@ -88,9 +88,9 @@ def collate_func(inps):
 
 print("Loading scorers")
 
-coverage_model_file = os.path.join(models_folder, "bert_coverage.bin")
+coverage_model_file = os.path.join(models_folder, "bert_coverage_google_cnndm_length15_1.bin")
 coverage_keyword_model_file = os.path.join(models_folder, "keyword_extractor.joblib")
-fluency_news_model_file = os.path.join(models_folder, "fluency_news_bs32.bin")
+fluency_news_model_file = os.path.join(models_folder, "news_gpt2_bs32.bin")
 
 scorers = [{"name": "coverage", "importance": 10.0, "sign": 1.0, "model": KeywordCoverage(args.device, keyword_model_file=coverage_keyword_model_file, model_file=coverage_model_file)},
            {"name": "fluency", "importance": 2.0, "sign": 1.0, "model": GeneTransformer(max_output_length=args.max_output_length, device=args.device, starter_model=fluency_news_model_file)},
@@ -102,6 +102,7 @@ def collate_func(inps):
 def background_tokenizer(bodies, out_queue):
     out_queue.put([bert_tokenizer.encode(body) for body in bodies])
 
+
 my_queue = queue.Queue()
 print("Started training")
 
@@ -116,7 +117,7 @@ def background_tokenizer(bodies, out_queue):
 dataloader = DataLoader(dataset=dataset, batch_size=args.train_batch_size, sampler=RandomSampler(dataset), drop_last=True, collate_fn=collate_func)
 
 for epi in range(n_epochs):
-    print("=================== EPOCH",epi, "===================")
+    print("=================== EPOCH", epi, "===================")
     for ib, documents in enumerate(dataloader):
         Timer = {}
 
@@ -126,7 +127,7 @@ def background_tokenizer(bodies, out_queue):
         bodies = [" ".join(doc.split(" ")[:300]) for doc in documents]
 
         # We run tokenization in the background, as it is BERT tokenization only used after the summarizer has run. Saves about 5% of time.
-        thread1 = threading.Thread(target = background_tokenizer, args = (bodies, my_queue))
+        thread1 = threading.Thread(target=background_tokenizer, args=(bodies, my_queue))
         # bodies_bert_tokenized = [bert_tokenizer.enncode(body) for body in bodies] # This is the not background version
         thread1.start()
 
@@ -159,11 +160,11 @@ def background_tokenizer(bodies, out_queue):
             sampled_scores = torch.FloatTensor(sampled_scores).to(args.device)
 
             argmax_scores, _ = scorer['model'].score(argmax_summaries, bodies, bodies_tokenized=bodies_bert_tokenized, extra=extra, lengths=argmax_end_idxs)
-            argmax_scores = torch.FloatTensor(argmax_scores).to(args.device)
+            argmax_scores = torch.FloatTensor(argmax_scores).to(args.device)
 
             Timer["scores_"+scorer['name']] = time.time()-T
             total_sampled_scores += (scorer['sign'])*(scorer['importance'])*sampled_scores
-            total_argmax_scores += (scorer['sign'])*(scorer['importance'])*argmax_scores
+            total_argmax_scores += (scorer['sign'])*(scorer['importance'])*argmax_scores
             log_obj[scorer['name']+"_score"] = sampled_scores.mean().item()
             scores_track[scorer['name']+"_scores"] = sampled_scores
 
@@ -180,7 +181,7 @@ def background_tokenizer(bodies, out_queue):
         T6 = time.time()
         Timer['backward'] = T6-T5
 
-        if ib%args.optim_every == 0:
+        if ib % args.optim_every == 0:
             optimizer.step()
             optimizer.zero_grad()
 
@@ -220,7 +221,7 @@ def background_tokenizer(bodies, out_queue):
 
         if ckpt_every > 0 and len(total_score_history) > ckpt_lookback:
             current_score = np.mean(total_score_history[-ckpt_lookback:])
-
+
             if time.time()-time_ckpt > ckpt_every:
                 revert_ckpt = best_ckpt_score is not None and current_score < min(1.2*best_ckpt_score, 0.8*best_ckpt_score) # Could be negative or positive
                 print("================================== CKPT TIME, "+str(datetime.now())+" =================================")
@@ -232,7 +233,7 @@ def background_tokenizer(bodies, out_queue):
                     optimizer.load_state_dict(torch.load(ckpt_optimizer_file))
                 time_ckpt = time.time()
                 print("==============================================================================")
-
+
             if best_ckpt_score is None or current_score > best_ckpt_score:
                 print("[CKPT] Saved new best at: %.3f %s" % (current_score, "["+str(datetime.now())+"]"))
                 best_ckpt_score = current_score
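For context on the `threading.Thread(target=background_tokenizer, ...)` line reformatted above: BERT tokenization of the input bodies is pushed to a background thread and collected from a queue once the summarizer has produced its outputs. A self-contained sketch of that pattern, with a whitespace-split placeholder standing in for the repo's bert_tokenizer:

import queue
import threading

def background_tokenizer(bodies, out_queue):
    # Placeholder for bert_tokenizer.encode(body); the real code produces BERT token ids.
    out_queue.put([body.split() for body in bodies])

bodies = ["first document body", "second document body"]
my_queue = queue.Queue()

thread1 = threading.Thread(target=background_tokenizer, args=(bodies, my_queue))
thread1.start()

# ... the summarizer would generate its sampled and argmax summaries here ...

bodies_tokenized = my_queue.get()   # blocks until the background thread has put its result
thread1.join()
print(bodies_tokenized)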

utils_tokenizer.py

Lines changed: 21 additions & 15 deletions
@@ -1,7 +1,6 @@
-from transformers.tokenization_gpt2 import GPT2Tokenizer as GPT2Tok
-from transformers.tokenization_bert import BertTokenizer as BertTok
+from transformers import GPT2Tokenizer as GPT2Tok
+from transformers import BertTokenizer as BertTok
 import sentencepiece as spm
-import nltk
 
 class Capita:
     def forward(self, text):
@@ -26,18 +25,25 @@ def forward(self, text):
     def backward(self, text):
         words = text.split(" ")
         final_words = []
-        all_caps = False; capitalized = False
+        all_caps = False
+        capitalized = False
         for w in words:
-            if w == "⇧": all_caps = True
-            elif w == "↑": capitalized = True
+            if w == "⇧":
+                all_caps = True
+            elif w == "↑":
+                capitalized = True
             else:
                 final_word = w
-                if all_caps: final_word = final_word.upper()
+                if all_caps:
+                    final_word = final_word.upper()
                 elif capitalized:
-                    if len(final_word) <= 1: final_word = final_word.upper()
-                    else: final_word = final_word[0].upper()+final_word[1:]
+                    if len(final_word) <= 1:
+                        final_word = final_word.upper()
+                    else:
+                        final_word = final_word[0].upper()+final_word[1:]
                 final_words.append(final_word)
-                all_caps = False; capitalized = False
+                all_caps = False
+                capitalized = False
         return " ".join(final_words)
 
 class BPETokenizer:
@@ -53,7 +59,7 @@ def __init__(self, bpe_model, use_capita=True):
 
         if self.use_capita:
             self.cpt = Capita()
-
+
     def tokenize(self, text):
         if len(text) == 0:
             return []
@@ -67,12 +73,12 @@ def tokenize(self, text):
         if tokens[0] == "▁":
             tokens = tokens[1:]
         return tokens
-
+
     def encode(self, text):
         tokens = self.tokenize(text)
         token_ids = [self.sp.piece_to_id(w) for w in tokens]
         return token_ids
-
+
     def decode(self, token_ids):
         text = self.sp.decode_ids(token_ids).replace("⇧", " ⇧").replace("↑", " ↑")
         if self.use_capita:
@@ -108,8 +114,8 @@ def __init__(self):
 
         self.pad_id = 0
         self.start_id = self.tokenizer.encode(self.start_tok)[0]
-        self.end_id = self.tokenizer.encode(self.end_tok)[0]
-        self.vocab_size = self.tokenizer.vocab_size
+        self.end_id = self.tokenizer.encode(self.end_tok)[0]
+        self.vocab_size = self.tokenizer.vocab_size
 
     def tokenize(self, text):
         return self.tokenizer.tokenize(text)
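As a quick check of the reformatted Capita.backward above: "⇧" marks an all-caps word and "↑" marks a capitalized word, and backward() restores the original casing from those markers. A short usage sketch (the sample string is made up for illustration):

from utils_tokenizer import Capita

cpt = Capita()
print(cpt.backward("↑ the ⇧ usa team"))   # -> "The USA team"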
