
Commit 5f9af3c

[Bugfix] [Examples] Fix remaining Gemma generations (#1604)

## Purpose ##

* Fix generation for Gemma models
* See: #1517

## Changes ##

* Disable compilation for remaining Gemma models

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 95e8bdc · commit 5f9af3c
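For context, here is a minimal standalone sketch of the workaround this commit applies: passing `disable_compile=True` to `generate()` so transformers skips `torch.compile` during decoding (see https://github.com/huggingface/transformers/issues/38333). This is not part of the commit itself; it assumes a transformers version recent enough to accept the flag, and the checkpoint name is illustrative.

```python
# Minimal sketch of the workaround (not part of this commit).
# Assumes a transformers version whose generate() accepts disable_compile;
# the model ID below is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b-it"  # illustrative Gemma 2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
# disable_compile=True skips the automatic torch.compile of the forward
# pass that triggers the Gemma 2 generation error in the linked issue.
output = model.generate(input_ids, max_new_tokens=20, disable_compile=True)
print(tokenizer.decode(output[0]))
```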

File tree: 2 files changed (+4 −2 lines changed)

* examples/quantization_kv_cache/gemma2_fp8_kv_example.py
* examples/quantization_w8a8_fp8/gemma2_example.py


examples/quantization_kv_cache/gemma2_fp8_kv_example.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -87,11 +87,12 @@ def process_and_tokenize(example):
 # NOTE: transformers 4.49.0 results in a generation error with gemma2.
 # Consider either downgrading your transformers version to a previous version
 # or use vLLM for sample generation.
+# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
 print("\n\n")
 dispatch_for_generation(model)
 print("========== SAMPLE GENERATION ==============")
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
+output = model.generate(input_ids, max_new_tokens=100, disable_compile=True)
 print(tokenizer.decode(output[0]))
 print("==========================================\n\n")
```

examples/quantization_w8a8_fp8/gemma2_example.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -29,10 +29,11 @@
 # NOTE: transformers 4.49.0 results in a generation error with gemma2.
 # Consider either downgrading your transformers version to a previous version
 # or use vLLM for sample generation.
+# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=20)
+output = model.generate(input_ids, max_new_tokens=20, disable_compile=True)
 print(tokenizer.decode(output[0]))
 print("==========================================")
```

0 commit comments
