Open
Description
neural_sp/neural_sp/models/seq2seq/decoders/las.py
Lines 535 to 548 in 2b10b9c
```python
for b in range(bs):
    nbest_hyps_id_b = [np.fromiter(y, dtype=np.int64) for y in nbest_hyps_id[b]]
    nbest_hyps_id_batch += nbest_hyps_id_b
    scores_b = np2tensor(np.array(scores[b], dtype=np.float32), eouts.device)
    probs_b_norm = torch.softmax(scaling_factor * scores_b, dim=-1)  # `[nbest]`
    wers_b = np2tensor(np.array([
        compute_wer(ref=idx2token(ys_ref[b]).split(' '),
                    hyp=idx2token(nbest_hyps_id_b[n]).split(' '))[0] / 100
        for n in range(nbest)], dtype=np.float32), eouts.device)
    exp_wer_b = (probs_b_norm * wers_b).sum()
    grad_list += [(probs_b_norm * (wers_b - exp_wer_b)).sum()]
    exp_wer += exp_wer_b
exp_wer /= bs
```
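To make the arithmetic in the excerpt concrete, here is a minimal NumPy sketch of the per-utterance computation. It assumes the n-best scores and the per-hypothesis WERs (already divided by 100, as in the excerpt) are given as plain arrays; the function name `mwer_terms` and its signature are hypothetical and not part of neural_sp. It renormalizes the scaled scores with a softmax, takes the expected WER, and returns the gradient of that expectation with respect to each score.

```python
import numpy as np


def mwer_terms(scores, wers, scaling_factor=1.0):
    """Expected WER over one n-best list and its gradient w.r.t. the scores.

    Hypothetical helper, not part of neural_sp.
    scores: shape [nbest], hypothesis scores (e.g. log-probabilities)
    wers:   shape [nbest], per-hypothesis WER as a fraction (0.25 == 25%)
    """
    scores = np.asarray(scores, dtype=np.float64)
    wers = np.asarray(wers, dtype=np.float64)
    # Renormalize over the n-best only: softmax of the scaled scores.
    logits = scaling_factor * scores
    logits -= logits.max()  # numerical stability; the softmax is unchanged
    probs = np.exp(logits) / np.exp(logits).sum()
    # Expected risk over the n-best (the per-utterance mWER loss).
    exp_wer = float((probs * wers).sum())
    # d exp_wer / d scores[k] = scaling_factor * probs[k] * (wers[k] - exp_wer)
    grad = scaling_factor * probs * (wers - exp_wer)
    return exp_wer, grad


exp_wer, grad = mwer_terms([-1.0, -2.0, -3.0], [0.0, 0.5, 1.0], scaling_factor=1.0)
```

The vector `probs_b_norm * (wers_b - exp_wer_b)` that the excerpt builds before calling `.sum()` corresponds to this `grad` divided by `scaling_factor`. Because the renormalized probabilities sum to one, those per-hypothesis terms sum to zero across the n-best (up to floating-point error).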
I don't know much about MBR (minimum Bayes risk) training, but judging from these lines, this looks like an mWER (minimum word error rate) loss and its gradient to me.
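For comparison, here is the usual n-best formulation of mWER, which is minimum Bayes risk training with per-hypothesis word errors as the risk. It is written with the excerpt's quantities: $s_n$ for `scores_b`, $\alpha$ for `scaling_factor`, $W_n$ for `wers_b`, and $N$ for `nbest`. This is a sketch of the standard objective, not a statement of what neural_sp intends.

```math
\hat{P}(y_n \mid x) = \frac{\exp(\alpha s_n)}{\sum_{m=1}^{N} \exp(\alpha s_m)}, \qquad
\mathcal{L}_{\text{mWER}} = \sum_{n=1}^{N} \hat{P}(y_n \mid x)\, W_n,
\qquad
\frac{\partial \mathcal{L}_{\text{mWER}}}{\partial s_k} = \alpha\, \hat{P}(y_k \mid x)\,\bigl(W_k - \mathcal{L}_{\text{mWER}}\bigr)
```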