We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6a4c991 commit a05bf7fCopy full SHA for a05bf7f
tests/spyre_util.py
@@ -10,6 +10,7 @@
10
import openai
11
import pytest
12
import requests
13
+import torch
14
from sentence_transformers import SentenceTransformer, util
15
from transformers import AutoModelForCausalLM, AutoTokenizer
16
from vllm import LLM, SamplingParams
@@ -228,7 +229,8 @@ def generate_hf_output(
228
229
if not isinstance(max_new_tokens, list):
230
max_new_tokens = [max_new_tokens] * len(prompts)
231
- hf_model = AutoModelForCausalLM.from_pretrained(model)
232
+ hf_model = AutoModelForCausalLM.from_pretrained(model,
233
+ torch_dtype=torch.float16)
234
hf_tokenizer = AutoTokenizer.from_pretrained(model)
235
236
results = []
0 commit comments