File tree Expand file tree Collapse file tree 2 files changed +9
-4
lines changed
src/llmcompressor/modeling Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Original file line number Diff line number Diff line change 1
1
from datasets import load_dataset
2
- from transformers import AutoModelForCausalLM , AutoTokenizer
2
+ from transformers import AutoConfig , AutoModelForCausalLM , AutoTokenizer
3
3
4
4
from llmcompressor .modeling import prepare_for_calibration
5
5
from llmcompressor .modifiers .quantization import GPTQModifier
9
9
# For DeepSeek-R1, we require a full precision model in order to properly calibrate
10
10
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16
11
11
model_id = "unsloth/DeepSeek-R1-0528-BF16"
12
- model = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = "auto" )
12
+ config = AutoConfig .from_pretrained (model_id )
13
+ del config .quantization_config # fp8 qconfig no longer appplies to bf16 model
14
+ model = AutoModelForCausalLM .from_pretrained (
15
+ model_id , torch_dtype = "auto" , config = config
16
+ )
13
17
tokenizer = AutoTokenizer .from_pretrained (model_id )
14
18
model = prepare_for_calibration (model )
15
19
Original file line number Diff line number Diff line change 13
13
14
14
def prepare_for_calibration (model : PreTrainedModel ) -> PreTrainedModel :
15
15
def replace (module : torch .nn .Module ) -> torch .nn .Module :
16
- if module .__class__ .__name__ in replacements :
17
- return replacements [module .__class__ ](module )
16
+ cls_name = module .__class__ .__name__
17
+ if cls_name in replacements :
18
+ return replacements [cls_name ](module )
18
19
else :
19
20
return module
20
21
You can’t perform that action at this time.
0 commit comments