Skip to content

Commit 8052351

Browse files
authored
Updated p4d results using new int_mm kernel (#61)
1 parent 92c4bda commit 8052351

File tree

4 files changed

+18
-14
lines changed

4 files changed

+18
-14
lines changed

experiments/eval_combo.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77
import segment_anything_fast
88

99
torch._dynamo.config.cache_size_limit = 50000
10-
# torch._inductor.config.fx_graph_cache = True # seems to slow performance
11-
torch._inductor.config.epilogue_fusion = False
12-
torch._inductor.config.coordinate_descent_tuning = True
13-
torch._inductor.config.coordinate_descent_check_all_directions = True
14-
torch._inductor.config.force_fuse_int_mm_with_mul = True
1510

1611
def unbind_jagged(device, data, sizes, offsets):
1712
if data is None:
@@ -175,7 +170,8 @@ def build_results(batched_data_iter,
175170
use_compile,
176171
use_compile_decoder,
177172
use_nested_tensor,
178-
pad_input_image_batch):
173+
pad_input_image_batch,
174+
use_fullgraph=False):
179175

180176
# TODO: Re-enable this for datapoints
181177
assert not use_compile_decoder
@@ -197,7 +193,7 @@ def build_results(batched_data_iter,
197193
if batch_idx == 0:
198194
with torch.autograd.profiler.record_function("compilation and warmup"):
199195
if str(use_compile) != "False":
200-
predictor.model.image_encoder = torch.compile(predictor.model.image_encoder, mode=use_compile, fullgraph=True,)
196+
predictor.model.image_encoder = torch.compile(predictor.model.image_encoder, mode=use_compile, fullgraph=use_fullgraph)
201197
# Run first batch a few times for warmup and exclude it from the final timings
202198
for _ in range(3):
203199
_ = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
@@ -293,10 +289,17 @@ def run(
293289
profile_top=False,
294290
memory_path=None,
295291
use_local_sam_fork=False,
292+
use_compiler_settings=False,
296293
):
297-
from torch._inductor import config as tritonconfig
298-
tritonconfig.triton.unique_kernel_names = True
299-
tritonconfig.epilogue_fusion_first = epilogue_fusion_first
294+
from torch._inductor import config as inductorconfig
295+
inductorconfig.triton.unique_kernel_names = True
296+
inductorconfig.epilogue_fusion_first = epilogue_fusion_first
297+
298+
if use_compiler_settings:
299+
# inductorconfig.fx_graph_cache = True # seems to slow performance
300+
inductorconfig.epilogue_fusion = False
301+
inductorconfig.coordinate_descent_tuning = True
302+
inductorconfig.coordinate_descent_check_all_directions = True
300303

301304
if use_half is not None:
302305
if use_half == "float16":
@@ -336,6 +339,7 @@ def run(
336339
if compress == "dynamic_quant":
337340
from segment_anything_fast.dynamic_quant import apply_dynamic_quant
338341
apply_dynamic_quant(predictor.model.image_encoder)
342+
inductorconfig.force_fuse_int_mm_with_mul = True
339343
elif compress == "static_quant":
340344
from segment_anything_fast.static_quant import apply_static_quant
341345
apply_static_quant(predictor.model.image_encoder)

experiments/p4d_results/results_bs8.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ compile,2.652463134129842,codesign,2.2.0.dev20231023+cu121,vit_b,8,7916,19,54.71
55
SDPA,2.148758562405904,sdpa-decoder,2.2.0.dev20231023+cu121,vit_b,8,4679,11,73.1570663251564,13.669219533153035,0.5355346808697282,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None
66
Triton,2.0386854648590087,local-fork,2.2.0.dev20231023+cu121,vit_b,8,1703,4,85.53658249838097,11.690904298391018,0.5339075529136259,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None
77
NT,1.9225259701410928,local-fork,2.2.0.dev20231023+cu121,vit_b,8,2797,6,92.11983959049361,10.85542489484747,0.5337810700594795,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,619,4952,None,None
8-
int8,4.885190570354462,local-fork,2.2.0.dev20231023+cu121,vit_b,8,2710,6,91.04449841914705,10.983640059130643,0.5331727804156572,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None
8+
int8,3.0942590634028115,local-fork,2.2.0.dev20231026+cu121,vit_b,8,2712,6,90.34645860305645,11.068502467745533,0.5330263012027209,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None
99
sparse,3.841790223121643,local-fork,2.2.0.dev20231023+cu121,vit_b,8,3217,7,81.4912293589238,12.271259224665185,0.4783508911148021,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,619,4952,None,None

experiments/p4d_results/results_bs8_vit_h.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ compile,7.968364950021108,codesign,2.2.0.dev20231024+cu121,vit_h,8,12358,30,19.6
55
SDPA,5.843019040425618,sdpa-decoder,2.2.0.dev20231024+cu121,vit_h,8,7947,19,21.92026495560376,45.61988653081299,0.581191777206921,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None
66
Triton,9.09047209819158,local-fork,2.2.0.dev20231024+cu121,vit_h,8,4550,11,22.874989934428537,43.71586623060877,0.5820036887609843,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None
77
NT,5.455243261655172,local-fork,2.2.0.dev20231024+cu121,vit_h,8,4550,11,23.206823845253847,43.09077393219044,0.5809004559961229,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,619,4952,None,None
8-
int8,8.623369554678598,local-fork,2.2.0.dev20231024+cu121,vit_h,8,3239,7,25.099473512089347,39.841473149559995,0.5820724009353484,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None
8+
int8,6.994769084453583,local-fork,2.2.0.dev20231026+cu121,vit_h,8,4167,10,24.87583443921619,40.19965651578395,0.5819033780783904,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None
99
sparse,5.597406772772471,local-fork,2.2.0.dev20231024+cu121,vit_h,8,7055,17,24.900183397177024,40.16034677533225,0.5289167514647479,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,619,4952,None,None

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
version='0.2',
88
packages=packages,
99
install_requires=[
10-
'torch>=2.2.0.dev20231019',
11-
'torchvision>=0.17.0.dev20231019',
10+
'torch>=2.2.0.dev20231026',
11+
'torchvision>=0.17.0.dev20231026',
1212
'diskcache',
1313
'pycocotools',
1414
'scipy',

0 commit comments

Comments
 (0)