Skip to content

Commit 19bd380

Browse files
committed
Testing autoquant: switching compression from dynamic_quant to auto_quant improves runtime from 19.70 to 19.76 img/sec
Summary: improves runtime by 19.70 -> 19.76 img/sec ❯ one sh run.sh 0%| | 0/64 [00:00<?, ?it/s]/home/cdhernandez/local/pytorch/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at /home/cdhernandez/local/pytorch/aten/src/ATen/NestedTensorImpl.cpp:177.) return _nested.nested_tensor( 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [06:32<00:00, 6.14s/it] sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path vit_h,16,14532,17,18.861125832244333,53.01910442113876,0.5865236891447146,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,64,1024,None,None 0%| | 0/64 [00:00<?, ?it/s]/home/cdhernandez/local/pytorch/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at /home/cdhernandez/local/pytorch/aten/src/ATen/NestedTensorImpl.cpp:177.) 
return _nested.nested_tensor( 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [07:08<00:00, 6.70s/it] vit_h,16,14395,17,19.70834741975898,50.73992145061493,0.5875230894143607,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,64,1024,None,None <class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 3.850527899339795 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 4.3931088875979185 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 4.3931088875979185 3.190660197287798 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight'> 4.768232116475701 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight3'> 3.8598313461989164 shape=(torch.Size([78400, 1280]), torch.Size([3840, 1280]), torch.Size([3840])), dtype=torch.bfloat16, best_cls=<class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> <class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 1.4865157660096884 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 1.8800818361341953 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 1.8800818361341953 1.179535873234272 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight'> 1.7427184619009497 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight3'> 1.4965661568567157 shape=(torch.Size([78400, 1280]), torch.Size([1280, 1280]), torch.Size([1280])), dtype=torch.bfloat16, best_cls=<class 'torchao.quantization.autoquant.AQFloatLinearWeight'> <class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 4.215262923389673 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 4.661373794078827 <class 
'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 4.661373794078827 3.485689079388976 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight'> 5.220260447822511 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight3'> 4.2220821138471365 shape=(torch.Size([65536, 1280]), torch.Size([5120, 1280]), torch.Size([5120])), dtype=torch.bfloat16, best_cls=<class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> <class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 4.666170105338097 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 4.113288130611181 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 4.113288130611181 2.626298717223108 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight'> 4.855024302378297 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight3'> 4.674202110618353 shape=(torch.Size([65536, 5120]), torch.Size([1280, 5120]), torch.Size([1280])), dtype=torch.bfloat16, best_cls=<class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> <class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 3.2269158866256475 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 3.7462301552295685 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 3.7462301552295685 2.6572815608233213 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight'> 3.9978391956537966 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight3'> 3.2370124012231827 shape=(torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])), dtype=torch.bfloat16, best_cls=<class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> <class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 1.2530277017503977 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 
1.5717314090579748 <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> 1.5717314090579748 0.9894231799989939 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight'> 1.5166664496064186 <class 'torchao.quantization.autoquant.AQWeightOnlyQuantizedLinearWeight3'> 1.2606457574293017 shape=(torch.Size([65536, 1280]), torch.Size([1280, 1280]), torch.Size([1280])), dtype=torch.bfloat16, best_cls=<class 'torchao.quantization.autoquant.AQFloatLinearWeight'> 0%| | 0/64 [00:00<?, ?it/s]/home/cdhernandez/local/pytorch/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at /home/cdhernandez/local/pytorch/aten/src/ATen/NestedTensorImpl.cpp:177.) return _nested.nested_tensor( 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [02:15<00:00, 2.12s/it] vit_h,16,14463,17,19.76190752324237,50.602402567863464,0.5875653903095147,max-autotune,torch.bfloat16,auto_quant,False,False,True,True,True,32,64,1024,None,None Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: a9134d6 Pull Request resolved: #114
1 parent 387488b commit 19bd380

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

experiments/eval_combo.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def run(
289289
profile_top=False,
290290
memory_path=None,
291291
use_local_sam_fork=False,
292-
use_compiler_settings=False,
292+
use_compiler_settings=True,
293293
):
294294
from torch._inductor import config as inductorconfig
295295
inductorconfig.triton.unique_kernel_names = True
@@ -298,6 +298,7 @@ def run(
298298
if use_compiler_settings:
299299
# inductorconfig.fx_graph_cache = True # seems to slow performance
300300
inductorconfig.epilogue_fusion = False
301+
torch._dynamo.config.automatic_dynamic_shapes = False
301302
inductorconfig.coordinate_descent_tuning = True
302303
inductorconfig.coordinate_descent_check_all_directions = True
303304

@@ -336,7 +337,13 @@ def run(
336337
for block in predictor.model.image_encoder.blocks:
337338
block.attn.use_rel_pos = use_rel_pos
338339

339-
if compress == "dynamic_quant":
340+
if compress == "auto_quant":
341+
from torchao.quantization.quant_api import do_autoquant
342+
example_input = torch.randn((16, 3, 1024, 1024), dtype=use_half, device="cuda")
343+
inductorconfig.force_fuse_int_mm_with_mul = True
344+
inductorconfig.use_mixed_mm = True
345+
do_autoquant(predictor.model.image_encoder, example_input)
346+
elif compress == "dynamic_quant":
340347
from torchao.quantization import apply_dynamic_quant
341348
apply_dynamic_quant(predictor.model.image_encoder)
342349
inductorconfig.force_fuse_int_mm_with_mul = True

experiments/run.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
SEGMENT_ANYTHING_FAST_USE_FLASH_4=0 python run_experiments.py 16 vit_h \
2+
~/local/pytorch ~/local/segment-anything ~/local/sam_data \
3+
--run-experiments --local_fork_only \
4+
--num-workers 32 --capture_output False

experiments/run_experiments.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def run_experiment(experiments_data,
4242
extra_args=None,
4343
print_header=False,
4444
capture_output=True,
45-
limit=None,
45+
limit=1024,
4646
profile_path=None,
4747
profile_top=False,
4848
memory_path=None):
@@ -181,6 +181,10 @@ def run(batch_size,
181181
rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
182182

183183
if run_experiments:
184+
# rexp("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), print_header=print_header)
185+
rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="dynamic_quant")
186+
# rexp("auto_quant", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="auto_quant")
187+
return
184188
if local_fork_only:
185189
rexp("fp32", "local-fork", print_header=print_header)
186190
rexp("bf16", "local-fork", use_half="bfloat16")

0 commit comments

Comments
 (0)