@@ -56,9 +56,6 @@
 
 target_folder = "/home/jerryzh/local/tmp/20241104_dynamo_test"
 
-prepare_target_folder(target_folder)
-
-
 __all__ = [
     "AutoQuantizableLinearWeight",
     "autoquant_v2",
@@ -128,29 +125,36 @@ def update_cache(gm, cls, shapes_and_dtype, res):
 
 # adjust each input's bsz to target_bsz
 # enable grad
+# a hacky solution, but it should work in the use cases we are testing now:
+# go through the list of sizes and swap any dimension that matches extracted_bsz to target_bsz
 def resize_input(t, extracted_bsz, target_bsz):
     if len(t.shape) > 1:
-        old_first_dim, old_second_dim, old_rest = t.size()[0], t.size()[1], t.size()[2:]
-        assert old_first_dim == 1
-        assert (
-            old_second_dim % extracted_bsz == 0
-        ), f"unexpected old_first_dim {old_first_dim} target_bsz {target_bsz}"
-        new_second_dim = old_second_dim // extracted_bsz * target_bsz
-        new_shape = (old_first_dim, new_second_dim, *old_rest)
+        new_shape = []
+        for i in range(len(t.size())):
+            if t.size(i) == extracted_bsz:
+                new_shape.append(target_bsz)
+            else:
+                new_shape.append(t.size(i))
         t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
     return t
 
 
+# a hacky solution, but it should work in the use cases we are testing now:
+# go through the list of sizes and swap any dimension that matches extracted_bsz to target_bsz
 def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz):
     """
    Makes guesses on how to adjust the model graph to account for the
     fact that we changed the batch size. Note: this is very brittle
     """
     for n in m.graph.nodes:
         if n.op == "call_method" and n.target == "view":
-            if n.args[2] == extracted_bsz:
-                new_args = (*n.args[:2], target_bsz, *n.args[3:])
-                n.args = new_args
+            new_args = []
+            for arg in n.args:
+                if arg == extracted_bsz:
+                    new_args.append(target_bsz)
+                else:
+                    new_args.append(arg)
+            n.args = tuple(new_args)
 
     m.recompile()
 
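For reference, a minimal self-contained sketch of what the two rewritten helpers now do. `resize_input` below is a standalone copy of the patched logic, and the FX snippet mimics `maybe_adjust_model_bsz` on a toy module (the module `M` and all shapes are made up for illustration):

```python
import torch
import torch.fx as fx

# Standalone copy of the patched resize_input: every dimension whose size
# equals extracted_bsz is swapped to target_bsz.
def resize_input(t, extracted_bsz, target_bsz):
    if len(t.shape) > 1:
        new_shape = [
            target_bsz if t.size(i) == extracted_bsz else t.size(i)
            for i in range(len(t.size()))
        ]
        t = torch.randn(*new_shape, dtype=t.dtype, device=t.device)
    return t

print(resize_input(torch.randn(1, 256, 64), 256, 8).shape)  # torch.Size([1, 8, 64])

# The same swap applied to `view` nodes in an FX graph, mirroring
# maybe_adjust_model_bsz (toy module, hardcoded sizes):
class M(torch.nn.Module):
    def forward(self, x):
        return x.view(1, 256, 64)

gm = fx.symbolic_trace(M())
for n in gm.graph.nodes:
    if n.op == "call_method" and n.target == "view":
        n.args = tuple(8 if a == 256 else a for a in n.args)
gm.recompile()
print(gm(torch.randn(1, 8, 64)).shape)  # torch.Size([1, 8, 64])
```

The old code assumed the batch dimension was always dim 1 under a leading 1; the new version swaps any dimension whose size matches `extracted_bsz`, which is looser but covers more of the extracted subgraphs.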
@@ -181,6 +185,7 @@ def __new__(
         fqn=None,
         example_inputs=None,
         fqn_to_submodule=None,
+        batch_size=None,
         **kwargs,
     ):
         kwargs["device"] = weight.device
@@ -204,6 +209,7 @@ def __init__(
         fqn=None,
         example_inputs=None,
         fqn_to_submodule=None,
+        batch_size=None,
         **kwargs,
     ):
         self.weight = weight
@@ -214,6 +220,7 @@ def __init__(
         self.fqn = fqn
         self.example_inputs = example_inputs
         self.fqn_to_submodule = fqn_to_submodule
+        self.batch_size = batch_size
 
     def __repr__(self):
         return (
@@ -236,7 +243,7 @@ def log_shape(act_mat, w_autoquant, bias):
         )
 
     def tune_autoquant2(
-        self, fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
+        self, fqn, m, batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
     ):
         act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
@@ -248,8 +255,8 @@ def tune_autoquant2(
                         linear_module = module
                 weight = q_cls.from_float(linear_module.weight)
                 linear_module.weight = torch.nn.Parameter(weight, requires_grad=False)
-                if LLAMA:
-                    extracted_bsz = 256
+                if batch_size is not None:
+                    extracted_bsz = batch_size
                     target_bsz = act_shape[0]
                     inputs = tree_map(
                         lambda t: resize_input(t, extracted_bsz, target_bsz), inputs
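This hunk replaces the hardcoded Llama-only branch (`if LLAMA: extracted_bsz = 256`) with the user-supplied `batch_size`. A rough sketch of the resize step in isolation, reusing the `resize_input` sketch above (the `inputs` pytree here is hypothetical):

```python
import torch
from torch.utils._pytree import tree_map

# hypothetical pytree of example inputs captured at extraction time (bsz=256)
inputs = ([torch.randn(1, 256, 64)], {"mask": torch.randn(1, 256)})

extracted_bsz = 256  # previously hardcoded behind `if LLAMA:`
target_bsz = 8       # act_shape[0] observed at benchmark time

# resize every tensor leaf so its bsz-sized dimensions match target_bsz
resized = tree_map(lambda t: resize_input(t, extracted_bsz, target_bsz), inputs)
print(resized[1]["mask"].shape)  # torch.Size([1, 8])
```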
@@ -329,7 +336,7 @@ def count_shapes(self, do_print=True):
                        else time_for_best_shape
                    )
                    self.tune_autoquant2(
-                        fqn, m, inputs, q_cls, shapes_and_dtype, time_for_best_shape
+                        fqn, m, self.batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape
                    )
                    ran_new_benchmarks = True
                    torch._dynamo.reset()
@@ -368,6 +375,7 @@ def _apply_fn_to_data(self, fn):
             fqn=self.fqn,
             example_inputs=self.example_inputs,
             fqn_to_submodule=self.fqn_to_submodule,
+            batch_size=self.batch_size,
         )
 
     def __tensor_flatten__(self):
@@ -378,6 +386,7 @@ def __tensor_flatten__(self):
             self.fqn,
             self.example_inputs,
             self.fqn_to_submodule,
+            self.batch_size,
             self.dtype,
             self.shape,
         ]
@@ -394,6 +403,7 @@ def __tensor_unflatten__(
             fqn,
             example_inputs,
             fqn_to_submodule,
+            batch_size,
             dtype,
             shape,
         ) = tensor_attributes
@@ -405,6 +415,7 @@ def __tensor_unflatten__(
             fqn=fqn,
             example_inputs=example_inputs,
             fqn_to_submodule=fqn_to_submodule,
+            batch_size=batch_size,
             shape=shape if outer_size is None else outer_size,
             dtype=dtype,
             strides=outer_stride,
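The `__tensor_flatten__`/`__tensor_unflatten__` hunks have to stay in lockstep: the attribute list built in one is unpacked positionally in the other, so a field added on only one side would be silently dropped whenever the subclass is flattened and rebuilt (e.g. across a torch.compile boundary). A toy model of that contract, as plain functions rather than the real subclass:

```python
# Toy stand-in for the flatten/unflatten contract: batch_size must be
# emitted and consumed at the same position on both sides.
def flatten(attrs):
    return [attrs["fqn"], attrs["fqn_to_submodule"], attrs["batch_size"]]

def unflatten(tensor_attributes):
    fqn, fqn_to_submodule, batch_size = tensor_attributes
    return {"fqn": fqn, "fqn_to_submodule": fqn_to_submodule, "batch_size": batch_size}

roundtrip = unflatten(flatten({"fqn": "layers.0", "fqn_to_submodule": {}, "batch_size": 256}))
assert roundtrip["batch_size"] == 256
```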
@@ -480,16 +491,6 @@ def do_autoquant_bench(op, *args, **kwargs):
     return res
 
 
-@torch.no_grad()
-def do_autoquant_bench2(model, *args, **kwargs):
-    rep = kwargs.pop("rep", 200)
-    warmup = kwargs.pop("warmup", 30)
-
-    torch._dynamo.reset()
-    benchmark_model(model, warmup, args, kwargs)
-    return benchmark_model(model, rep, args, kwargs)
-
-
 def _is_interpolate_mode(mode):
     if (
         isinstance(mode, list)
@@ -997,7 +998,7 @@ def dict_union(*args):
 
 
 def _change_linears_to_autoquantizable(
-    model, example_input, fqn_to_submodule, **kwargs
+    model, example_input, fqn_to_submodule, batch_size, **kwargs
 ):
     """
     Converts all linear weight tensors to the
@@ -1017,6 +1018,7 @@ def _change_linears_to_autoquantizable(
     kwargs["model"] = model
     kwargs["example_inputs"] = example_input
     kwargs["fqn_to_submodule"] = fqn_to_submodule
+    kwargs["batch_size"] = batch_size
     from torchao.quantization.quant_api import _get_subclass_inserter
 
     _replace_with_custom_fn_if_matches_filter(
@@ -1090,6 +1092,7 @@ def autoquant_v2(
     manual=False,
     set_inductor_config=True,
     supress_autoquant_errors=True,
+    batch_size=None,
     **aq_kwargs,
 ):
     """
@@ -1151,6 +1154,7 @@ def autoquant_v2(
 
     assert example_input is not None
 
+    prepare_target_folder(target_folder)
     torch._dynamo.reset()
     # TODO: explore using node.meta to retrieve the subgraph and fqn information
     # disable nn module inlining, our subgraph extraction logic depends on this
@@ -1168,6 +1172,8 @@ def autoquant_v2(
     else:
         raise Exception("Unexpected example_input:", example_input)
 
+    torch._inductor.config.pre_grad_custom_pass = None
+
     # verify debug logs and summary got saved
     assert os.path.isfile(
         os.path.join(target_folder, "debug_logs_0.txt")
@@ -1221,6 +1227,7 @@ def autoquant_v2(
         model,
         example_input,
         fqn_to_submodule,
+        batch_size,
         filter_fn=filter_fn,
         qtensor_class_list=qtensor_class_list,
         mode=mode,
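End to end, the new keyword flows `autoquant_v2` → `_change_linears_to_autoquantizable` → `AutoQuantizableLinearWeight` → `tune_autoquant2`. A hedged usage sketch; the import path, the toy model, and the shapes are assumptions for illustration, not taken from this diff:

```python
import torch
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(256, 64, dtype=torch.bfloat16, device="cuda")

# batch_size marks which size in the captured example inputs is the batch
# dimension, so extracted subgraphs can be re-benchmarked at other batch sizes
# instead of relying on the old hardcoded LLAMA=256 assumption.
model = autoquant_v2(model, example_input=example_input, batch_size=256)
```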