From 74f61b145fa778d8e4ecebcef30988e858283650 Mon Sep 17 00:00:00 2001
From: Michal Mucha <7082264+MichaMucha@users.noreply.github.com>
Date: Sun, 2 Apr 2023 23:50:03 +0100
Subject: [PATCH] prevent prompt tensors from accumulating in GPU

---
 inference/bot.py | 39 ++++++++++++++++++++++-----------------
 1 file changed, 22 insertions(+), 17 deletions(-)

diff --git a/inference/bot.py b/inference/bot.py
index 00a4f05..4d952aa 100644
--- a/inference/bot.py
+++ b/inference/bot.py
@@ -85,23 +85,28 @@ def __init__(self, model_name, gpu_id, max_memory):
 
     def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
         stop_criteria = StopWordsCriteria(self._tokenizer, [self.human_id], stream_callback)
-        inputs = (
-            self._tokenizer(prompt, return_tensors='pt')
-            .to(self._model.device)
-        )
-        outputs = self._model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            pad_token_id=self._tokenizer.eos_token_id,
-            stopping_criteria=StoppingCriteriaList([stop_criteria]),
-        )
-        output = self._tokenizer.batch_decode(outputs)[0]
-
-        # remove the context from the output
-        output = output[len(prompt):]
+        try:
+            inputs = (
+                self._tokenizer(prompt, return_tensors='pt')
+                .to(self._model.device)
+            )
+            outputs = self._model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                pad_token_id=self._tokenizer.eos_token_id,
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
+            )
+            del inputs
+            output = self._tokenizer.batch_decode(outputs)[0]
+            del outputs
+
+            # remove the context from the output
+            output = output[len(prompt):]
+        finally:
+            torch.cuda.empty_cache()
 
         return output
 
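Note for reviewers: the change wraps generation in a try/finally so that the prompt and output tensors stay scoped to the call, the references are dropped as soon as they are decoded, and the caching allocator's freed blocks are returned to the GPU even when generation raises. Below is a minimal standalone sketch of that same pattern; the helper name generate_and_free and its signature are illustrative only (not part of bot.py), and it assumes any Hugging Face model/tokenizer pair like the ones bot.py wraps.

    import torch

    def generate_and_free(model, tokenizer, prompt, **generate_kwargs):
        # Hypothetical helper illustrating the patch's pattern; not part of bot.py.
        try:
            inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
            outputs = model.generate(
                **inputs,
                pad_token_id=tokenizer.eos_token_id,
                **generate_kwargs,
            )
            del inputs   # drop the prompt tensors once generation is done
            text = tokenizer.batch_decode(outputs)[0]
            del outputs  # drop the generated token ids once decoded
            # remove the prompt context from the decoded text
            return text[len(prompt):]
        finally:
            # Release cached allocator blocks back to the GPU even if
            # tokenization or generation raised.
            torch.cuda.empty_cache()

Keeping torch.cuda.empty_cache() in the finally clause rather than on the happy path is what prevents a failed generation from leaving the prompt tensors cached on the device between requests.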