Skip to content

Commit a05bf7f

Browse files
🚨 set dtype=float16 for tests
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
1 parent 6a4c991 commit a05bf7f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/spyre_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import openai
1111
import pytest
1212
import requests
13+
import torch
1314
from sentence_transformers import SentenceTransformer, util
1415
from transformers import AutoModelForCausalLM, AutoTokenizer
1516
from vllm import LLM, SamplingParams
@@ -228,7 +229,8 @@ def generate_hf_output(
228229
if not isinstance(max_new_tokens, list):
229230
max_new_tokens = [max_new_tokens] * len(prompts)
230231

231-
hf_model = AutoModelForCausalLM.from_pretrained(model)
232+
hf_model = AutoModelForCausalLM.from_pretrained(model,
233+
torch_dtype=torch.float16)
232234
hf_tokenizer = AutoTokenizer.from_pretrained(model)
233235

234236
results = []

0 commit comments

Comments
 (0)