From 4df4657e8307ab5d588b9c1b60d54d975c29cdec Mon Sep 17 00:00:00 2001 From: Karol Kontny Date: Fri, 2 Aug 2024 12:49:56 +0200 Subject: [PATCH] Fixing mixtral runner --- .../text_generation/mixtral/run.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/natural_language_processing/text_generation/mixtral/run.py b/natural_language_processing/text_generation/mixtral/run.py index 5f7e2eb1..6dbfaf30 100644 --- a/natural_language_processing/text_generation/mixtral/run.py +++ b/natural_language_processing/text_generation/mixtral/run.py @@ -18,11 +18,12 @@ sys.exit(1) -def run_pytorch(num_runs, timeout, dataset_path, disable_jit_freeze=False, **kwargs): +def run_pytorch(num_runs, timeout, dataset_path, use_torch_fp16=False): from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from utils.nlp.alpaca_instruct import AlpacaInstruct from utils.pytorch import PyTorchRunnerV2, apply_compile from utils.benchmark import run_model + import torch def run_single_pass(pytorch_runner, dataset): input_array = [{"role": "user", "content": dataset.get_input_string()}] @@ -33,8 +34,12 @@ def run_single_pass(pytorch_runner, dataset): response = decode(outputs[:, inputs.shape[1]:])[0] dataset.submit_prediction(response) + # This is needed for TorchDynamo to correctly capture where op + torch._dynamo.config.capture_dynamic_output_shape_ops = True model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1") model.eval() + if use_torch_fp16: + model = model.half() model.forward = apply_compile(model.forward) tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1") @@ -51,8 +56,12 @@ def run_single_pass(pytorch_runner, dataset): return run_model(run_single_pass, runner, dataset, 1, num_runs, timeout) -def run_pytorch_fp32(num_runs, timeout, dataset_path, disable_jit_freeze=False, **kwargs): - return run_pytorch(num_runs, timeout, dataset_path, disable_jit_freeze, **kwargs) +def run_pytorch_fp32(num_runs, timeout, dataset_path, **kwargs): + return run_pytorch(num_runs, timeout, dataset_path) + + +def run_pytorch_fp16(num_runs, timeout, dataset_path, **kwargs): + return run_pytorch(num_runs, timeout, dataset_path, use_torch_fp16=True, **kwargs) def main():