Description
Hello, very impressive work here!
I am trying to understand the inference process. My understanding is that previous input tokens can be refined, since r_k, the current prediction, attends to all r_i for i = 1, ..., k-1. How can I modify the inference process to refine previous tokens? Do I need to disable the KV cache and pass the entire context to the transformer at each iteration?
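To clarify what I have in mind, here is a toy sketch of the loop I am imagining (all names here are illustrative, not from your codebase): with no KV cache, the full sequence is re-encoded on every iteration, so every earlier position r_1, ..., r_{k-1} gets re-predicted alongside r_k.

```python
def refine(model, tokens, num_iters=3):
    """Iterative refinement without a KV cache: the whole context is
    passed to the model on every call, so each position can be
    updated while attending to all other positions."""
    for _ in range(num_iters):
        tokens = model(tokens)  # re-encode the full sequence each step
    return tokens

# Stand-in "model" for illustration only: each position attends to the
# full sequence by being replaced with the rounded mean of all tokens.
def toy_model(tokens):
    mean = round(sum(tokens) / len(tokens))
    return [mean for _ in tokens]

print(refine(toy_model, [1, 5, 9], num_iters=2))  # → [5, 5, 5]
```

Is this roughly the right picture, or is there a cheaper way to refresh earlier tokens without recomputing the entire context each iteration?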
Thank you for your help!