5
5
from data import build_data , setup_coco_img_ids
6
6
import math
7
7
import segment_anything_fast
8
+ import torchao
8
9
9
10
# Set TorchDynamo's `cache_size_limit` config option to 50000 at import time,
# before any model is compiled.
# NOTE(review): presumably this is raised so torch.compile keeps recompiling
# instead of falling back when many distinct inputs are seen during the
# benchmark — confirm against the PyTorch docs for the installed version.
torch ._dynamo .config .cache_size_limit = 50000
10
11
@@ -289,7 +290,7 @@ def run(
289
290
profile_top = False ,
290
291
memory_path = None ,
291
292
use_local_sam_fork = False ,
292
- use_compiler_settings = False ,
293
+ use_compiler_settings = True ,
293
294
):
294
295
from torch ._inductor import config as inductorconfig
295
296
inductorconfig .triton .unique_kernel_names = True
@@ -298,6 +299,7 @@ def run(
298
299
if use_compiler_settings :
299
300
# inductorconfig.fx_graph_cache = True # seems to slow performance
300
301
inductorconfig .epilogue_fusion = False
302
+ torch ._dynamo .config .automatic_dynamic_shapes = False
301
303
inductorconfig .coordinate_descent_tuning = True
302
304
inductorconfig .coordinate_descent_check_all_directions = True
303
305
@@ -336,7 +338,12 @@ def run(
336
338
for block in predictor .model .image_encoder .blocks :
337
339
block .attn .use_rel_pos = use_rel_pos
338
340
339
- if compress == "dynamic_quant" :
341
+ if compress == "autoquant" :
342
+ example_input = torch .randn ((batch_size , 3 , 1024 , 1024 ), dtype = use_half , device = "cuda" )
343
+ inductorconfig .force_fuse_int_mm_with_mul = True
344
+ inductorconfig .use_mixed_mm = True
345
+ torchao .autoquant (predictor .model .image_encoder , example_input , mode = ["interpolate" , .5 ])
346
+ elif compress == "dynamic_quant" :
340
347
from torchao .quantization import apply_dynamic_quant
341
348
apply_dynamic_quant (predictor .model .image_encoder )
342
349
inductorconfig .force_fuse_int_mm_with_mul = True
0 commit comments