diff --git a/inference/bot.py b/inference/bot.py
index 00a4f05..4d952aa 100644
--- a/inference/bot.py
+++ b/inference/bot.py
@@ -85,23 +85,28 @@ def __init__(self, model_name, gpu_id, max_memory):
 
     def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
         stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
-        inputs = (
-            self._tokenizer(prompt, return_tensors='pt')
-            .to(self._model.device)
-        )
-        outputs = self._model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            pad_token_id=self._tokenizer.eos_token_id,
-            stopping_criteria=StoppingCriteriaList([stop_criteria]),
-        )
-        output = self._tokenizer.batch_decode(outputs)[0]
-
-        # remove the context from the output
-        output = output[len(prompt):]
+        try:
+            inputs = (
+                self._tokenizer(prompt, return_tensors='pt')
+                .to(self._model.device)
+            )
+            outputs = self._model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                pad_token_id=self._tokenizer.eos_token_id,
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
+            )
+            del inputs
+            output = self._tokenizer.batch_decode(outputs)[0]
+            del outputs
+
+            # remove the context from the output
+            output = output[len(prompt):]
+        finally:
+            torch.cuda.empty_cache()
 
         return output
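The change wraps tokenization and generation in `try`/`finally` so that per-call GPU allocations are released even if `generate()` raises (for example on a CUDA out-of-memory error). Below is a minimal, self-contained sketch of the same cleanup pattern; the function and variable names are illustrative and not taken from `bot.py`, and it assumes a generic Hugging Face causal LM and tokenizer already placed on a CUDA device.

```python
import torch


def generate_and_cleanup(model, tokenizer, prompt, max_new_tokens=128):
    """Run one generation call, then release per-call GPU allocations.

    `del` drops the Python references to the input/output tensors so their
    allocations become unreferenced; torch.cuda.empty_cache() then returns
    PyTorch's cached blocks to the driver. The finally block ensures the
    cache is cleared even if generate() raises.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
        del inputs
        text = tokenizer.batch_decode(outputs)[0]
        del outputs
        # Strip the prompt so only newly generated text is returned.
        return text[len(prompt):]
    finally:
        torch.cuda.empty_cache()
```

Note that `torch.cuda.empty_cache()` only releases memory that is no longer referenced, which is why the `del` statements in the diff precede it; without them, `inputs` and `outputs` would keep their allocations alive until the method returns.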