Hello,
I would like to try out word-level knowledge distillation (https://arxiv.org/abs/1606.07947). For this I need the probabilities of all tokens, or at least the top-k ones, at each decoding step. I see that it is already possible to print the probability of the generated token with the with_token_level flag, but it is not clear to me how to modify the code to get the probabilities of the top-k tokens at each step.
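To make concrete what I am after, here is a rough sketch in plain Python. It is not taken from the library's code: the function name `top_k_token_probs` and the input layout (one list of raw vocabulary logits per decoding step) are my own assumptions for illustration only.

```python
import math

def top_k_token_probs(step_logits, k):
    """Return, for each decoding step, the k most probable (token_id, prob) pairs.

    step_logits: hypothetical input, a list with one list of raw
    (unnormalized) vocabulary scores per decoding step.
    """
    results = []
    for logits in step_logits:
        # Numerically stable softmax over the vocabulary.
        max_logit = max(logits)
        exps = [math.exp(x - max_logit) for x in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Token ids sorted by descending probability, truncated to k.
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
        results.append([(i, probs[i]) for i in ranked])
    return results
```

In other words, I would like the decoder to expose something equivalent to this output at every step, rather than only the probability of the token that was actually generated.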
Any help is appreciated,
Thanks,
Z