[Code Sharing] Beam Search Implementation #53
shjas94
started this conversation in
Show and tell
Replies: 1 comment
-
# Assumes module-level `import random`, `import torch`, and a global `device`.
def forward(
    self, src, text, is_train=True, batch_max_length=50, teacher_forcing_ratio=1.0
):
    if is_train and random.random() < teacher_forcing_ratio:
        # Teacher forcing: decode the whole target sequence in parallel.
        tgt = self.text_embedding(text)
        tgt = self.pos_encoder(tgt)
        tgt_mask = self.pad_mask(text) | self.order_mask(text.size(1))
        for layer in self.attention_layers:
            tgt = layer(tgt, None, src, tgt_mask)
        out = self.generator(tgt)
    else:
        # Inference: step-by-step beam search decoding.
        num_steps = batch_max_length - 1
        # Each candidate: [token ids, running score, per-layer feature cache, per-step outputs].
        temp_tar = [[
            torch.LongTensor(src.size(0)).fill_(self.st_id).unsqueeze(1).to(device),
            torch.LongTensor(src.size(0)).fill_(0).to(device),
            [None] * self.layer_num,
            [],
        ]]
        k = self.k
        for t in range(num_steps):
            new_tar = []
            for tar in temp_tar:
                target = tar[0][:, -1].unsqueeze(1).to(device)
                pre_prob, pre_feature, temp_out = tar[1], tar[2], tar[3]
                tgt = self.text_embedding(target)
                tgt = self.pos_encoder(tgt, point=t)
                tgt_mask = self.order_mask(t + 1)
                tgt_mask = tgt_mask[:, -1].unsqueeze(1)  # [1, t+1]
                for l, layer in enumerate(self.attention_layers):
                    tgt = layer(tgt, pre_feature[l], src, tgt_mask)
                    # Cache this layer's output so later steps can attend over the prefix.
                    pre_feature[l] = (
                        tgt if pre_feature[l] is None else torch.cat([pre_feature[l], tgt], 1)
                    )
                _out = self.generator(tgt)
                # Expand this candidate with its k most probable next tokens.
                prob, idx = torch.topk(_out[:, -1:, :], k=k, dim=-1)  # [b, 1, k] each
                prob = prob.squeeze(1).transpose(0, 1)  # [k, b]
                idx = idx.squeeze(1).transpose(0, 1)    # [k, b]
                for j in range(k):
                    new_tar.append([
                        torch.cat([tar[0], idx[j].unsqueeze(1)], dim=1),
                        (prob[j] + pre_prob) / 2,
                        list(pre_feature),   # copy so sibling beams do not share one cache
                        temp_out + [_out],   # copy so sibling beams do not share one history
                    ])
            # new_tar now holds [sequence, score, ...] candidates; keep only the
            # k highest-scoring ones and drop the rest. Comparing per-batch score
            # tensors via .item() assumes batch size 1 (see the note below).
            sorted_tar = sorted(new_tar, key=lambda x: x[1].item(), reverse=True)
            temp_tar = sorted_tar[:k]
        # Take the per-step outputs of the best-scoring candidate.
        out = sorted(temp_tar, key=lambda x: x[1].item(), reverse=True)[0][-1]
        out = torch.stack(out, dim=1).to(device)  # [b, num_steps, 1, num_classes]
        out = out.squeeze(2)  # [b, num_steps, num_classes]
    return out

An early prototype that does not yet handle batching.
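To make the expand/score/prune loop easier to follow in isolation, here is a minimal, self-contained sketch of the same idea run against a toy next-token distribution. The `toy_step` function is a made-up stand-in for one decoder step, and cumulative log-probabilities are used for scoring instead of the averaged scores above; none of this is from the shared code.

import torch

def toy_step(seq):
    # Stand-in for one decoder step: deterministic fake log-probs over 5 classes.
    torch.manual_seed(int(seq.sum()))
    return torch.log_softmax(torch.randn(5), dim=-1)

def beam_search(start_id, num_steps=4, k=2):
    # Each candidate is (token ids, cumulative log-probability).
    beams = [(torch.tensor([start_id]), 0.0)]
    for _ in range(num_steps):
        candidates = []
        for seq, score in beams:
            log_probs = toy_step(seq)
            # Expand each surviving beam with its k most probable next tokens.
            top_p, top_i = torch.topk(log_probs, k=k)
            for p, i in zip(top_p, top_i):
                candidates.append((torch.cat([seq, i.view(1)]), score + p.item()))
        # Keep only the k best-scoring candidates, mirroring sorted_tar[:k] above.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return beams[0]

print(beam_search(start_id=0))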
-
This is a beam search implementation. Just replace the forward function of the TransformerDecoder class with this one. If anyone has a submission file ready, it would be great if you could apply it, submit, and share the results. A usage sketch follows below.
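A hedged usage sketch of the inference path: the `decoder` instance, `enc_out`, and the constructor arguments are assumptions, not part of the shared code. With is_train=False the beam-search branch is taken and `text` is unused, so None can be passed.

import torch

decoder = TransformerDecoder(...)  # placeholder: your existing, trained decoder
decoder.eval()
with torch.no_grad():
    # enc_out: encoder features, e.g. [batch=1, seq_len, hidden];
    # the prototype assumes batch size 1 (batching is not handled yet).
    out = decoder(enc_out, text=None, is_train=False, batch_max_length=50)
    pred = out.argmax(dim=-1)  # [1, batch_max_length - 1] predicted token ids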