README improvements
Summary:
quantization README:
1) added fp6 to benchmarks
2) rewrote autoquant section to give a higher level explanation before
diving into the details
3) reordered affine quantization section to first show techniques then
dive into details
4) added fp6 section
5) moved kv cache quantization to its own section
6) added sparse-marlin section and removed the sparse-marlin benchmark from
the top of the README since we don't have a reasonable flow for users to
apply it to their model without a pre-sparsified checkpoint.
7) added uintx section (a short usage sketch of these quantization APIs follows this list)
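For readers of this summary, here is a minimal sketch of how the techniques covered by the new/reworked sections are applied, assuming torchao's `quantize_` / `autoquant` entry points as documented in the quantization README; the model below is a placeholder, not something from this PR.

```python
# Minimal sketch, assuming torchao's quantize_ / autoquant entry points;
# the model below is a throwaway placeholder.
import torch
import torchao
from torchao.quantization import quantize_, fpx_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# fp6 (the new fp6 section): fpx weight-only quantization with 3 exponent / 2 mantissa bits
quantize_(model, fpx_weight_only(3, 2))

# autoquant (the rewritten autoquant section): lets torchao pick a quantization
# kernel per layer after compiling the model
# model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
```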
Benchmarks Changes:
1) added instructions for adding things to benchmarks so everything
stays consistent (in llama benchmark README)
2) organized and ran benchmarks for uintx, fp6, and sparse-marlin
3) added evaluations.sh to mirror benchmarks.sh
4) added sparse-marlin to eval.py
5) fixed some generate.py logging bugs
6) improved the quantization help text in generate.py
7) fixed some eval.py bugs with uintx
8) added marlin to eval
9) fixed eval help text
sparsity readme:
1) added some details to sparsity
Test Plan:
benchmarks.sh
evaluations.sh
Reviewers:
Subscribers:
Tasks:
Tags:
torchao/_models/llama/README.md: 4 additions & 0 deletions
@@ -27,3 +27,7 @@ To see how these techniques scale generally we've run `generate.py` with subsets
| 32768 | 23.83 | 21.72 | 20.64 |
| 65536 | 33.5 | 29.54 | 25.24 |
| 131072 | 59.27 | 52.62 | 34.18 |
+
+## Adding Benchmarks For New Techniques
+
+If you want to add benchmarks that you think should be kept up to date, please try to keep the format consistent. For performance focused techniques (e.g. if they require fine-tuning or something else), add an option to run them in generate.py and an execution command in benchmarks.sh in the relevant section. If it's a technique that's still in development, add it in the `OTHER BENCHMARKS` section; if there's a finalized API and you want those numbers in the main quantization README, add them in the `README BENCHMARKS` section. For accuracy focused techniques, add them to eval.py and evaluations.sh in a similar vein. Ideally, techniques in the main README will have both benchmarks and evaluations set up here so they can be monitored and reproduced easily.
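As a rough illustration of what the paragraph above asks for, the sketch below shows the kind of branch a new technique would add to the `--quantization` dispatch; it is not the actual generate.py code, `"my-new-technique"` and `apply_quantization` are hypothetical names, and `int8_weight_only()` merely stands in for the new technique's config.

```python
# Sketch only: "my-new-technique" and apply_quantization are placeholders,
# and int8_weight_only() stands in for the new technique's quantization config.
import torch
from torchao.quantization import quantize_, int8_weight_only

def apply_quantization(model: torch.nn.Module, quantization: str) -> torch.nn.Module:
    # generate.py switches on the --quantization string; adding a technique means
    # adding a branch like this plus a matching command in benchmarks.sh (and, for
    # accuracy-focused techniques, in eval.py / evaluations.sh).
    if quantization == "my-new-technique":  # hypothetical flag value
        quantize_(model, int8_weight_only())  # replace with the new technique's config
    return model

# Usage inside the script would look roughly like:
# model = apply_quantization(model, args.quantization)
```

The matching benchmarks.sh or evaluations.sh entry would then invoke generate.py or eval.py with that `--quantization` value alongside the flags those scripts already use.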
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')