
Commit 6895af1

[Bugfix] Fix gemma2 generation (#1552)

## Purpose

* Fix gemma2 generation
* See #1517

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent 8b2c612 commit 6895af1

File tree

1 file changed: +2 −1 lines changed


examples/quantization_w8a8_int8/gemma2_example.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -70,9 +70,10 @@ def tokenize(sample):
 # NOTE: transformers 4.49.0 results in a generation error with gemma2.
 # Consider either downgrading your transformers version to a previous version
 # or use vLLM for sample generation.
+# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
 print("========== SAMPLE GENERATION ==============")
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=20)
+output = model.generate(input_ids, max_new_tokens=20, disable_compile=True)
 print(tokenizer.decode(output[0]))
 print("==========================================")
```
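For context on the one-line fix: in recent transformers releases, keyword arguments passed to `model.generate()` that match `GenerationConfig` fields override the model's generation config, and `disable_compile` is one such flag that skips the automatic `torch.compile` of the decoding forward pass (the source of the gemma2 error linked above). A minimal sketch of that override path, assuming a transformers version new enough to recognize `disable_compile` (no model or GPU needed just to inspect the config):

```python
from transformers import GenerationConfig

# Sketch: generate(..., disable_compile=True) is equivalent to setting the
# flag on the GenerationConfig that generate() builds from its kwargs.
cfg = GenerationConfig(max_new_tokens=20, disable_compile=True)
print(cfg.disable_compile)  # the flag generate() will honor
print(cfg.max_new_tokens)
```

If your installed transformers predates the flag, the kwarg is not recognized, which is why the example's comment also suggests downgrading transformers or using vLLM as alternatives.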
