
Commit 6669637

Added mse range setting
Differential Revision: D77055545
Pull Request resolved: #11857
1 parent: 2f55193

File tree: 1 file changed (+58 −9 lines)


examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 58 additions & 9 deletions
@@ -20,6 +20,14 @@
     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -78,6 +88,33 @@ def forward(
         return self.model.forward(tokens, self.atten_mask)
 
 
+def add_mse_weight_observer(quant_dtype, quantizer):
+    weight_dtype = (
+        torch.int4
+        if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+        else torch.int8
+    )
+    per_channel_q_config = quantizer.default_quant_config.quant_config
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max),
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+            **{"steps": 200, "use_mse": True}
+        ),
+    )
+    quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
+        input_activation=per_channel_q_config.input_activation,
+        output_activation=per_channel_q_config.output_activation,
+        weight=weight_qspec,
+        bias=_derived_bias_quant_spec,
+    )
+
+
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
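The observer configured above, PerChannelParamObserver.with_args(steps=200, use_mse=True), searches for a clipping range per output channel instead of taking the raw weight min/max. The following is a minimal, self-contained sketch of that idea for a single channel; it only illustrates MSE-based range setting and is not the actual PerChannelParamObserver implementation. The int4-style limits (-7/7) and the 200 steps mirror the values in the hunk above.

# Conceptual sketch of MSE-based range setting for one weight channel.
# Not the PerChannelParamObserver implementation; it only illustrates the
# search that use_mse=True / steps=200 performs per channel.
import torch

def mse_range_search(w, quant_min=-7, quant_max=7, steps=200):
    """Return the symmetric clipping bound that minimizes quantization MSE."""
    max_abs = w.abs().max().item()
    best_bound, best_err = max_abs, float("inf")
    for i in range(1, steps + 1):
        bound = max_abs * i / steps          # candidate clipping range
        scale = bound / quant_max            # symmetric scale for this range
        q = torch.clamp(torch.round(w / scale), quant_min, quant_max)
        err = ((q * scale - w) ** 2).mean().item()
        if err < best_err:
            best_bound, best_err = bound, err
    return best_bound

# A heavy-tailed channel: plain min/max would spend most of the int4 grid on the
# single outlier, while the MSE search typically picks a much tighter bound.
w = torch.cat([torch.randn(512) * 0.01, torch.tensor([1.0])])
print(mse_range_search(w), w.abs().max().item())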
@@ -142,13 +179,13 @@ def permute(w, heads):
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.bfloat16)
+    model.to(dtype=torch.float)
     model.to(device=args.device)
 
     tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
     tokens = tokens.to(device=args.device)
     atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    atten_mask = atten_mask.to(dtype=torch.float)
     inputs = (tokens, atten_mask)
 
     if args.embedding_quantize:
@@ -174,7 +211,8 @@ def permute(w, heads):
         )
         quantizer.add_custom_quant_annotations(custom_annotations)
 
-        model.has_quant_io = True
+        if args.range_setting == "mse_weight":
+            add_mse_weight_observer(quant_dtype, quantizer)
 
         with torch.no_grad():
             model = torch.export.export(model, inputs, strict=True).module()
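Ordering matters in this hunk: the MSE weight observers have to be installed on the quantizer before prepare_pt2e places observers in the exported graph. A rough sketch of that flow is below, assuming model, inputs, and quantizer are the objects the script builds earlier and add_mse_weight_observer is the helper added in this commit; the single forward pass is only a stand-in for the script's real calibration.

# Sketch only: where the new range-setting hook sits relative to export,
# prepare_pt2e, calibration, and convert_pt2e.
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_model(model, inputs, quantizer, quant_dtype, range_setting=None):
    if range_setting == "mse_weight":
        # Swap the per-channel weight config for the MSE-searched observers
        # before any observers are inserted into the graph.
        add_mse_weight_observer(quant_dtype, quantizer)
    with torch.no_grad():
        exported = torch.export.export(model, inputs, strict=True).module()
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*inputs)            # stand-in calibration pass
        return convert_pt2e(prepared)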
@@ -245,6 +283,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
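All three new flags are parsed as plain strings here. A small, self-contained reproduction of just these arguments (plain argparse, without the other options that build_args_parser defines) shows how a typical invocation would parse; the flag values in the example are illustrative, not taken from the commit.

# Illustrative only: the three new flags in isolation, parsed the same way as
# a run with --ptq 16a4w --range_setting mse_weight --limit 0.1.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-P", "--ptq", type=str)
parser.add_argument("--range_setting", type=str)
parser.add_argument("--limit", type=str)

args = parser.parse_args(
    ["--ptq", "16a4w", "--range_setting", "mse_weight", "--limit", "0.1"]
)
print(args.ptq, args.range_setting, args.limit)   # -> 16a4w mse_weight 0.1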
@@ -257,15 +312,9 @@
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)
 