
Commit 6800f81

Authored by shanjiaz and dsikka
[BugFix] Fix quantization_2of4_sparse_w4a16 example (#1565)
SUMMARY: Pass the saved model path directly to the next stage, since `train` and `oneshot` can now properly initialize models from a path.

TEST PLAN: Ran `test_quantization_2of4_sparse_w4a16` locally:

```
collected 2 items
tests/examples/test_quantization_2of4_sparse_w4a16.py::TestQuantization24SparseW4A16::test_doc_example_comm PASSED
tests/examples/test_quantization_2of4_sparse_w4a16.py::TestQuantization24SparseW4A16::test_alternative_recipe PASSED
=========================================================== 2 passed in 6123.28s (1:42:03) ===========================================================
```

---------

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
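The core of the fix is that each stage's model is re-initialized from the previous stage's save directory rather than handed over as a live object. A minimal sketch of how those per-stage save paths are composed with `pathlib`, mirroring the example's layout (the `stage_dir` helper is hypothetical, added here only for illustration):

```python
from pathlib import Path

# Same layout as the example: each stage writes to output_dir/<stage_name>,
# and the following stage loads the model from that directory.
output_dir = "output_llama7b_2of4_w4a16_channel"
output_path = Path(output_dir)


def stage_dir(stage: str) -> Path:
    # Hypothetical helper (not part of the example): build a stage's save path.
    return output_path / stage


print(stage_dir("sparsity_stage"))
# output_llama7b_2of4_w4a16_channel/sparsity_stage (on POSIX systems)
```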
1 parent a144d8b commit 6800f81

File tree

1 file changed: +10 −7 lines

examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py

```diff
@@ -1,9 +1,10 @@
+from pathlib import Path
+
 import torch
 from loguru import logger
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot, train
-from llmcompressor.utils import dispatch_for_generation
 
 # load the model in as bfloat16 to save on memory and compute
 model_stub = "neuralmagic/Llama-2-7b-ultrachat200k"
@@ -18,6 +19,7 @@
 
 # save location of quantized model
 output_dir = "output_llama7b_2of4_w4a16_channel"
+output_path = Path(output_dir)
 
 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
@@ -63,25 +65,26 @@
 # ./output_llama7b_2of4_w4a16_channel/ + (finetuning/sparsity/quantization)_stage
 
 # Oneshot sparsification
-oneshot_applied_model = oneshot(
+
+oneshot(
     model=model,
     **oneshot_kwargs,
+    output_dir=output_dir,
     stage="sparsity_stage",
 )
 
 # Sparse finetune
-dispatch_for_generation(model)
-finetune_applied_model = train(
-    model=oneshot_applied_model,
+train(
+    model=(output_path / "sparsity_stage"),
     **oneshot_kwargs,
     **training_kwargs,
+    output_dir=output_dir,
     stage="finetuning_stage",
 )
 
 # Oneshot quantization
-model.to("cpu")
 quantized_model = oneshot(
-    model=finetune_applied_model,
+    model=(output_path / "finetuning_stage"),
     **oneshot_kwargs,
     stage="quantization_stage",
 )
```
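The diff replaces passing live model objects between stages with saving each stage's result to disk and re-loading the next stage by path. A toy, self-contained sketch of that hand-off pattern, using plain Python and JSON rather than the llmcompressor API (`run_stage` and the JSON "model" are illustrative stand-ins):

```python
import json
import tempfile
from pathlib import Path


def run_stage(model_or_path, stage: str, output_dir: Path) -> Path:
    """Toy stage runner: accept either an in-memory model (a dict here)
    or a path to a previous stage's save directory, and save the result
    under output_dir/<stage>, returning that path for the next stage."""
    if isinstance(model_or_path, (str, Path)):
        # Re-initialize the model from the previous stage's saved output.
        model = json.loads(Path(model_or_path, "model.json").read_text())
    else:
        model = model_or_path
    model["stages"] = model.get("stages", []) + [stage]
    stage_path = output_dir / stage
    stage_path.mkdir(parents=True, exist_ok=True)
    (stage_path / "model.json").write_text(json.dumps(model))
    return stage_path


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    # Each stage receives the previous stage's save path, not a live object.
    p = run_stage({"name": "llama7b"}, "sparsity_stage", out)
    p = run_stage(p, "finetuning_stage", out)
    p = run_stage(p, "quantization_stage", out)
    final = json.loads((p / "model.json").read_text())

print(final["stages"])
# ['sparsity_stage', 'finetuning_stage', 'quantization_stage']
```

This mirrors the fixed example's flow: `oneshot` saves to `output_dir/sparsity_stage`, `train` loads from that path and saves to `output_dir/finetuning_stage`, and the final `oneshot` loads from there.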
