
Question about the loss and grad of "mbr" #322

@Cescfangs

Description


```python
for b in range(bs):
    nbest_hyps_id_b = [np.fromiter(y, dtype=np.int64) for y in nbest_hyps_id[b]]
    nbest_hyps_id_batch += nbest_hyps_id_b
    scores_b = np2tensor(np.array(scores[b], dtype=np.float32), eouts.device)
    probs_b_norm = torch.softmax(scaling_factor * scores_b, dim=-1)  # `[nbest]`
    wers_b = np2tensor(np.array([
        compute_wer(ref=idx2token(ys_ref[b]).split(' '),
                    hyp=idx2token(nbest_hyps_id_b[n]).split(' '))[0] / 100
        for n in range(nbest)], dtype=np.float32), eouts.device)
    exp_wer_b = (probs_b_norm * wers_b).sum()
    grad_list += [(probs_b_norm * (wers_b - exp_wer_b)).sum()]
    exp_wer += exp_wer_b
exp_wer /= bs
```

I don't know much about MBR, but according to these lines it looks like an mWER (minimum expected word error rate) loss and gradient to me.
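For reference, the arithmetic in those lines can be sketched with plain NumPy. The `scores_b` and `wers_b` values here are made-up placeholders for one utterance's N-best list, and the local `softmax` stands in for `torch.softmax`:

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax, mirroring torch.softmax in the snippet."""
    e = np.exp(x - x.max())
    return e / e.sum()

scaling_factor = 1.0  # smoothing of the N-best posterior (assumed value)

# Hypothetical N-best scores and per-hypothesis WERs (in [0, 1]) for one utterance
scores_b = np.array([2.0, 1.0, 0.5], dtype=np.float32)
wers_b = np.array([0.10, 0.25, 0.40], dtype=np.float32)

# Renormalize scores over the N-best list
probs_b_norm = softmax(scaling_factor * scores_b)

# Expected WER under the renormalized N-best posterior: sum_n p_n * W_n
exp_wer_b = float((probs_b_norm * wers_b).sum())

# Per-hypothesis gradient weight: p_n * (W_n - E[W]), as in the snippet
grad_b = probs_b_norm * (wers_b - exp_wer_b)
```

With the posterior normalized over the N-best list, `exp_wer_b` is exactly the expected WER, and the `p_n * (W_n - E[W])` weight increases the score of hypotheses with below-average WER and decreases it for above-average ones, which is the usual mWER-style gradient.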
