
Commit 7c3c51f

Some improvement to make autoquant v2 work with Mixtral-8x7B-v0.1 (#1328)
* Some improvement to make autoquant v2 work with Mixtral-8x7B-v0.1

Summary:
Tested locally by running autoquant v2 with llama2-7b and Mixtral-8x7B-v0.1 in https://github.com/pytorch/pytorch/blob/main/benchmarks/gpt_fast/benchmark.py

Llama-2-7b-chat-hf:
Compilation time: 81.71 seconds
Average tokens/sec: 131.12 tokens/sec
Average bandwidth achieved: 1732.77 GB/s
Memory used: 27.71 GB

Mixtral-8x7B-v0.1:
Compilation time: 108.89 seconds
Average tokens/sec: 79.59 tokens/sec
Average bandwidth achieved: 1025.14 GB/s
Memory used: 61.62 GB

More results can be found in pytorch/pytorch#140627

Test Plan: local test with pytorch/benchmarks/gpt_fast/benchmark.py

Reviewers:
Subscribers:
Tasks:
Tags:

* remove print
1 parent f3c1a00 commit 7c3c51f
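
Note: the diff below threads a new optional batch_size argument through autoquant_v2. A minimal sketch of how a caller such as the gpt_fast benchmark might use it; the toy model, dtype, device, and shapes are illustrative assumptions, only autoquant_v2 and its example_input/batch_size parameters come from this commit.

import torch
import torch.nn as nn

from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

# Hypothetical setup: a small float16 CUDA model standing in for Llama/Mixtral.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda().half()
example_input = torch.randn(256, 64, dtype=torch.float16, device="cuda")

# batch_size tells autoquant v2 the batch size baked into the extracted subgraph
# inputs, so resize_input() can swap that dimension for the batch size of the
# shape actually being benchmarked (previously hard-coded as 256 behind LLAMA).
model = autoquant_v2(model, example_input=example_input, batch_size=256)

# As with autoquant v1, running the instrumented model on representative inputs
# drives the shape logging / benchmarking flow (assumption about usage).
model(example_input)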

File tree

1 file changed: 35 additions, 28 deletions

torchao/prototype/quantization/autoquant_v2.py

Lines changed: 35 additions & 28 deletions
@@ -56,9 +56,6 @@
 
 target_folder = "/home/jerryzh/local/tmp/20241104_dynamo_test"
 
-prepare_target_folder(target_folder)
-
-
 __all__ = [
     "AutoQuantizableLinearWeight",
     "autoquant_v2",

@@ -128,29 +125,36 @@ def update_cache(gm, cls, shapes_and_dtype, res):
 
 # adjust each input's bsz to target_bsz
 # enable grad
+# a hacky solution but should work in the use cases we are testing now
+# we went through the list of sizes and swap the dimension that matches extracted_bsz to target_bsz
 def resize_input(t, extracted_bsz, target_bsz):
     if len(t.shape) > 1:
-        old_first_dim, old_second_dim, old_rest = t.size()[0], t.size()[1], t.size()[2:]
-        assert old_first_dim == 1
-        assert (
-            old_second_dim % extracted_bsz == 0
-        ), f"unexpected old_first_dim {old_first_dim} target_bsz {target_bsz}"
-        new_second_dim = old_second_dim // extracted_bsz * target_bsz
-        new_shape = (old_first_dim, new_second_dim, *old_rest)
+        new_shape = []
+        for i in range(len(t.size())):
+            if t.size(i) == extracted_bsz:
+                new_shape.append(target_bsz)
+            else:
+                new_shape.append(t.size(i))
         t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
     return t
 
 
+# a hacky solution but should work in the use cases we are testing now
+# we went through the list of sizes and swap the dimension that matches extracted_bsz to target_bsz
 def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz):
     """
     Makes guesses on how to adjust the model graph to account for the
     fact that we changed the batch size. Note: this is very brittle
     """
     for n in m.graph.nodes:
         if n.op == "call_method" and n.target == "view":
-            if n.args[2] == extracted_bsz:
-                new_args = (*n.args[:2], target_bsz, *n.args[3:])
-                n.args = new_args
+            new_args = []
+            for arg in n.args:
+                if arg == extracted_bsz:
+                    new_args.append(target_bsz)
+                else:
+                    new_args.append(arg)
+            n.args = tuple(new_args)
 
     m.recompile()
 
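The rewritten resize_input drops the old layout assumption (first dim equal to 1, second dim a multiple of extracted_bsz) and instead swaps any dimension whose size equals extracted_bsz for target_bsz; maybe_adjust_model_bsz now applies the same swap across all args of view calls instead of only patching args[2]. A standalone sketch of the new resizing behavior with made-up shapes (resize_input_sketch just mirrors the committed logic for illustration):

import torch


def resize_input_sketch(t, extracted_bsz, target_bsz):
    # Mirror of the new resize_input: every dimension equal to extracted_bsz is
    # replaced with target_bsz, all other dimensions are kept as-is.
    if len(t.shape) > 1:
        new_shape = [target_bsz if d == extracted_bsz else d for d in t.size()]
        t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
    return t


x = torch.randn(1, 256, 4096)  # a captured subgraph input with batch size 256
y = resize_input_sketch(x, extracted_bsz=256, target_bsz=8)
print(y.shape)  # torch.Size([1, 8, 4096])
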
@@ -181,6 +185,7 @@ def __new__(
         fqn=None,
         example_inputs=None,
         fqn_to_submodule=None,
+        batch_size=None,
         **kwargs,
     ):
         kwargs["device"] = weight.device

@@ -204,6 +209,7 @@ def __init__(
         fqn=None,
         example_inputs=None,
         fqn_to_submodule=None,
+        batch_size=None,
         **kwargs,
     ):
         self.weight = weight

@@ -214,6 +220,7 @@ def __init__(
         self.fqn = fqn
         self.example_inputs = example_inputs
         self.fqn_to_submodule = fqn_to_submodule
+        self.batch_size = batch_size
 
     def __repr__(self):
         return (

@@ -236,7 +243,7 @@ def log_shape(act_mat, w_autoquant, bias):
         )
 
     def tune_autoquant2(
-        self, fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
+        self, fqn, m, batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
     ):
         act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
 

@@ -248,8 +255,8 @@ def tune_autoquant2(
                 linear_module = module
                 weight = q_cls.from_float(linear_module.weight)
                 linear_module.weight = torch.nn.Parameter(weight, requires_grad=False)
-                if LLAMA:
-                    extracted_bsz = 256
+                if batch_size is not None:
+                    extracted_bsz = batch_size
                     target_bsz = act_shape[0]
                     inputs = tree_map(
                         lambda t: resize_input(t, extracted_bsz, target_bsz), inputs

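With the hard-coded LLAMA branch replaced, tune_autoquant2 only resizes the captured inputs when the caller passes batch_size, and still uses tree_map to walk whatever pytree of tensors was extracted. A self-contained sketch of that resizing step with dummy inputs (the helper and shapes are illustrative, not from the commit):

import torch
from torch.utils._pytree import tree_map

extracted_bsz = 256  # now supplied via the batch_size argument
target_bsz = 8       # act_shape[0] of the shape currently being benchmarked


def resize(t):
    # Same dimension swap as resize_input, inlined for this sketch.
    if isinstance(t, torch.Tensor) and len(t.shape) > 1:
        new_shape = [target_bsz if d == extracted_bsz else d for d in t.size()]
        t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
    return t


# tree_map applies resize to every tensor leaf, however the inputs are nested.
inputs = ((torch.randn(1, 256, 64), torch.randn(64, 64)), {"mask": torch.ones(1, 256)})
resized = tree_map(resize, inputs)
print([tuple(t.shape) for t in resized[0]])  # [(1, 8, 64), (64, 64)]
print(tuple(resized[1]["mask"].shape))       # (1, 8)
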
@@ -329,7 +336,7 @@ def count_shapes(self, do_print=True):
                     else time_for_best_shape
                 )
                 self.tune_autoquant2(
-                    fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
+                    fqn, m, self.batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
                 )
                 ran_new_benchmarks = True
                 torch._dynamo.reset()

@@ -368,6 +375,7 @@ def _apply_fn_to_data(self, fn):
             fqn=self.fqn,
             example_inputs=self.example_inputs,
             fqn_to_submodule=self.fqn_to_submodule,
+            batch_size=self.batch_size,
         )
 
     def __tensor_flatten__(self):

@@ -378,6 +386,7 @@ def __tensor_flatten__(self):
             self.fqn,
             self.example_inputs,
             self.fqn_to_submodule,
+            self.batch_size,
             self.dtype,
             self.shape,
         ]

@@ -394,6 +403,7 @@ def __tensor_unflatten__(
             fqn,
             example_inputs,
             fqn_to_submodule,
+            batch_size,
             dtype,
             shape,
         ) = tensor_attributes

@@ -405,6 +415,7 @@ def __tensor_unflatten__(
             fqn=fqn,
             example_inputs=example_inputs,
             fqn_to_submodule=fqn_to_submodule,
+            batch_size=batch_size,
             shape=shape if outer_size is None else outer_size,
             dtype=dtype,
             strides=outer_stride,

@@ -480,16 +491,6 @@ def do_autoquant_bench(op, *args, **kwargs):
     return res
 
 
-@torch.no_grad()
-def do_autoquant_bench2(model, *args, **kwargs):
-    rep = kwargs.pop("rep", 200)
-    warmup = kwargs.pop("warmup", 30)
-
-    torch._dynamo.reset()
-    benchmark_model(model, warmup, args, kwargs)
-    return benchmark_model(model, rep, args, kwargs)
-
-
 def _is_interpolate_mode(mode):
     if (
         isinstance(mode, list)

@@ -997,7 +998,7 @@ def dict_union(*args):
 
 
 def _change_linears_to_autoquantizable(
-    model, example_input, fqn_to_submodule, **kwargs
+    model, example_input, fqn_to_submodule, batch_size, **kwargs
 ):
     """
     Converts all linear weight tensors to the

@@ -1017,6 +1018,7 @@ def _change_linears_to_autoquantizable(
     kwargs["model"] = model
     kwargs["example_inputs"] = example_input
     kwargs["fqn_to_submodule"] = fqn_to_submodule
+    kwargs["batch_size"] = batch_size
     from torchao.quantization.quant_api import _get_subclass_inserter
 
     _replace_with_custom_fn_if_matches_filter(

@@ -1090,6 +1092,7 @@ def autoquant_v2(
     manual=False,
     set_inductor_config=True,
     supress_autoquant_errors=True,
+    batch_size=None,
     **aq_kwargs,
 ):
     """

@@ -1151,6 +1154,7 @@ def autoquant_v2(
 
     assert example_input is not None
 
+    prepare_target_folder(target_folder)
     torch._dynamo.reset()
     # TODO: explore using node.meta to retrieve the subgraph and fqn information
     # disable nn module inlining, our subgraph extraction logic depends on this

@@ -1168,6 +1172,8 @@ def autoquant_v2(
     else:
         raise Exception("Unexpected example_input:", example_input)
 
+    torch._inductor.config.pre_grad_custom_pass = None
+
     # verify debug logs and summary got saved
     assert os.path.isfile(
         os.path.join(target_folder, "debug_logs_0.txt")

@@ -1221,6 +1227,7 @@ def autoquant_v2(
         model,
         example_input,
         fqn_to_submodule,
+        batch_size,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
         mode=mode,
