 import itertools
 import functools

-home = "/home/cpuhrsch"
-
-sam_path = "/scratch/cpuhrsch/dev/segment-anything"
 sam_commits = {
     "default": "6fdee8f2727f4506cfbbe553e23b895e27956588",
     "graphbreaks": "55f772f77864752f2e98a6fc7713b45a1843c167",
     ...
     "wip-flash-sdpa-decoder": "bb1c8b6f3749b1a5f31635f5d2f26bcafa9d94f9"}


-def change_sam_commit(commit_name):
+
+def change_sam_commit(sam_path, commit_name):
     assert commit_name in sam_commits
     root_cmd = ["git", "-C", sam_path]
     result = subprocess.run(
@@ -31,11 +29,12 @@ def change_sam_commit(commit_name):


 def run_experiment(experiments_data,
+                   sam_path,
+                   model_type,
                    idx,
                    sam_commit_name,
-                   model_type,
-                   batch_size,
-                   num_workers,
+                   batch_size=1,
+                   num_workers=0,
                    use_half=None,
                    use_compile="False",
                    compress=None,
@@ -68,7 +67,7 @@ def run_experiment(experiments_data,
     if sam_commit_name == "local-fork":
         args = args + ["--use_local_sam_fork", "True"]
     else:
-        change_sam_commit(sam_commit_name)
+        change_sam_commit(sam_path, sam_commit_name)
     if use_half:
         args = args + ["--use_half", use_half]
     if compress is not None:
@@ -107,16 +106,14 @@ def run_experiment(experiments_data,
     print(prefix + "," + result.stdout.decode().split("\n")[-2])


-def run_traces(*args, **kwargs):
+def run_traces_fn(traces_dir, pytorch_path, rexp, *args, **kwargs):
     # Limit to 10 batches
     kwargs['limit'] = 160
-    # Folder to save results to
-    traces_dir = "/home/cpuhrsch/tmp/traces/20230924"

     # Create kernel traces
     profile_path = f"{traces_dir}/{args[0]}.json.gz"
     kwargs['profile_path'] = profile_path
-    run_experiment(*args, **kwargs)
+    rexp(*args, **kwargs)
     kwargs['profile_path'] = None

     # Don't print header again if already printed
@@ -129,41 +126,65 @@ def run_traces(*args, **kwargs):

     memory_path = f"{traces_dir}/{args[0]}"
     kwargs['memory_path'] = memory_path + ".pickle"
-    run_experiment(*args, **kwargs)
+    rexp(*args, **kwargs)
     kwargs['memory_path'] = None

     # Convert memory trace to html page
-    conversion_cmd = ["python", "/home/cpuhrsch/dev/pytorch/torch/cuda/_memory_viz.py",
+    conversion_cmd = ["python", f"{pytorch_path}/torch/cuda/_memory_viz.py",
                       "trace_plot", memory_path + ".pickle", "-o", memory_path + ".html"]
     result = subprocess.run(conversion_cmd, capture_output=True)
     assert result.returncode == 0

-def run(experiments_data=None):
+def run(batch_size,
+        model,
+        experiments_data=None,
+        run_traces=False,
+        run_experiments=False,
+        traces_dir=None,
+        num_workers=32,
+        print_header=True):
+
+    pytorch_path = "/home/cpuhrsch/dev/pytorch"
+    sam_path = "/home/cpuhrsch/dev/segment-anything"
+    assert model == "vit_b" or model == "vit_h"
+
     if experiments_data is None:
         experiments_data = "experiments_data"

-    # run_traces("fp32", "default", "vit_b", 16, 32, print_header=True)
-    # run_traces("fp16", "codesign", "vit_b", 16, 32, use_half=True)
-    # run_traces("compile", "codesign", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
-    # run_traces("SDPA", "sdpa-decoder", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
-    # run_traces("Triton", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
-    # run_traces("NT", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True)
-    # run_traces("int8", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
-    # run_traces("sparse", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
+    rexp = functools.partial(run_experiment,
+                             experiments_data,
+                             sam_path,
+                             model,
+                             batch_size=batch_size,
+                             num_workers=num_workers)

-    rexp = functools.partial(run_experiment, experiments_data)
     print_header = True
-    for bs, model in itertools.product([1, 32], ["vit_b", "vit_h"]):
-        # rexp("fp32", "default", model, bs, 32, print_header=print_header)
+    if run_traces:
+        assert traces_dir is not None
+        rt = functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp)
+
+        rt("fp32", "default", capture_output=False)
+        rt("fp16", "codesign", use_half="bfloat16")
+        rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
+        rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
+        rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
+        if batch_size > 1:
+            rt("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
+            rt("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
+            rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
+
+    if run_experiments:
+        rexp("fp32", "default", print_header=print_header, capture_output=False)
         print_header = False
-        # rexp("bf16", "codesign", model, bs, 32, use_half="bfloat16")
-        # rexp("compile", "codesign", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
-        # rexp("SDPA", "sdpa-decoder", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
-        rexp("Triton", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", capture_output=False)
-        if bs > 1:
-            rexp("NT", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1))
-            rexp("int8", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="dynamic_quant")
-            rexp("sparse", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="sparse")
+        rexp("bf16", "codesign", use_half="bfloat16")
+        rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
+        rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
+        rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
+        if batch_size > 1:
+            rexp("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
+            rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
+            rexp("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
+

 if __name__ == '__main__':
     fire.Fire(run)
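
A minimal usage sketch of the refactored entry point follows. The file name run_experiments.py and the example values below are assumptions, not part of this commit; fire.Fire(run) exposes batch_size and model as required positional arguments and everything else as optional flags:

    # Hypothetical call with the new signature; only batch_size and model are required.
    from run_experiments import run
    run(batch_size=16, model="vit_b", run_experiments=True, num_workers=32)

    # Equivalent python-fire command line (positional batch_size and model, flags for the rest):
    #   python run_experiments.py 16 vit_b --run_experiments True --num_workers 32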