
Commit d963a88

Fix generate.py for fbgemm int4 integration (#2273)
Summary:

Updated benchmark table (rows are quantization scheme - batch size; TTFT in seconds):

| quantization - batch size | overall tokens/sec | TTFT (s) | Peak Memory | Model Size |
| ------------------------------ | ------------------ | -------- | ----------- | ---------- |
| baseline - 1                   | 131.65             | 0.0220   | 16.24 GB    | 15.01 GB   |
| baseline - 128                 | 76.38              | 0.0544   | 26.92 GB    | 15.01 GB   |
| int4wo - 1                     | 207.69             | 0.0288   | 6.41 GB     | 3.99 GB    |
| int4wo - 128                   | 12.85              | 0.4223   | 16.01 GB    | 3.99 GB    |
| fbgemm-int4 - 1 (w/ compile)   | 61.12              | 0.0212   | 7.59 GB     | 3.00 GB    |
| fbgemm-int4 - 128 (w/ compile) | 71.23              | 0.0576   | 16.13 GB    | 3.99 GB    |

Verified that compile works:

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization fbgemm-int4-128 --batch_size 1 --compile
==========
Average overall tokens/sec: 61.12
Average decode tokens/sec: 61.5512 s
Average TTFT: 0.0212 s
Average tokens/sec: 61.12
Average Bandwidth: 243.70 GB/s
Peak Memory Usage: 7.59 GB
Model Size: 3.99 GB

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization fbgemm-int4-128 --batch_size 128 --compile
==========
Average overall tokens/sec: 71.23
Average decode tokens/sec: 72.8871 s
Average TTFT: 0.0576 s
Average tokens/sec: 71.23
Average tokens/sec including batches 9116.81
Average Bandwidth: 284.00 GB/s
Peak Memory Usage: 16.13 GB
Model Size: 3.99 GB

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
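For interpreting the numbers above: the reported Average Bandwidth appears to be roughly the model size multiplied by the overall tokens/sec (weight bytes streamed per generated token). That formula is an assumption about how generate.py derives the figure, not something this commit changes; a quick arithmetic check against the batch-size-1 fbgemm-int4 run:

```python
# Rough sanity check of the reported bandwidth (assumed formula, not from this commit):
# bandwidth ~= model size in GB * overall tokens/sec.
model_size_gb = 3.99       # "Model Size" reported by generate.py
tokens_per_sec = 61.12     # "Average overall tokens/sec" for batch_size 1

est_bandwidth = model_size_gb * tokens_per_sec
print(f"{est_bandwidth:.1f} GB/s")  # ~243.9 GB/s, close to the reported 243.70 GB/s
```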
1 parent dacd3aa commit d963a88

File tree: 1 file changed (+7 −1 lines changed)


torchao/_models/llama/generate.py

Lines changed: 7 additions & 1 deletion
@@ -447,8 +447,14 @@ def ffn_or_attn_only(mod, fqn):
 
             _, precision, group_size = quantization.split("-")
             group_size = int(group_size)
+            block_size = [1, group_size]
             if precision == "int4":
-                quantize_(model, FbgemmConfig("bf16i4bf16", group_size))
+                quantize_(
+                    model,
+                    FbgemmConfig(
+                        torch.bfloat16, torch.int4, torch.bfloat16, block_size
+                    ),
+                )
             else:
                 raise NotImplementedError(
                     f"FbegemmConfig({precision=}) not supported yet"
