Skip to content

Commit c4676de

Browse files
authored
Record traces (#47)
1 parent 5298e02 commit c4676de

File tree

1 file changed

+55
-34
lines changed

1 file changed

+55
-34
lines changed

experiments/run_experiments.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
import itertools
44
import functools
55

6-
home = "/home/cpuhrsch"
7-
8-
sam_path = "/scratch/cpuhrsch/dev/segment-anything"
96
sam_commits = {
107
"default": "6fdee8f2727f4506cfbbe553e23b895e27956588",
118
"graphbreaks": "55f772f77864752f2e98a6fc7713b45a1843c167",
@@ -19,7 +16,8 @@
1916
"wip-flash-sdpa-decoder": "bb1c8b6f3749b1a5f31635f5d2f26bcafa9d94f9"}
2017

2118

22-
def change_sam_commit(commit_name):
19+
20+
def change_sam_commit(sam_path, commit_name):
2321
assert commit_name in sam_commits
2422
root_cmd = ["git", "-C", sam_path]
2523
result = subprocess.run(
@@ -31,11 +29,12 @@ def change_sam_commit(commit_name):
3129

3230

3331
def run_experiment(experiments_data,
32+
sam_path,
33+
model_type,
3434
idx,
3535
sam_commit_name,
36-
model_type,
37-
batch_size,
38-
num_workers,
36+
batch_size=1,
37+
num_workers=0,
3938
use_half=None,
4039
use_compile="False",
4140
compress=None,
@@ -68,7 +67,7 @@ def run_experiment(experiments_data,
6867
if sam_commit_name == "local-fork":
6968
args = args + ["--use_local_sam_fork", "True"]
7069
else:
71-
change_sam_commit(sam_commit_name)
70+
change_sam_commit(sam_path, sam_commit_name)
7271
if use_half:
7372
args = args + ["--use_half", use_half]
7473
if compress is not None:
@@ -107,16 +106,14 @@ def run_experiment(experiments_data,
107106
print(prefix + "," + result.stdout.decode().split("\n")[-2])
108107

109108

110-
def run_traces(*args, **kwargs):
109+
def run_traces_fn(traces_dir, pytorch_path, rexp, *args, **kwargs):
111110
# Limit to 10 batches
112111
kwargs['limit'] = 160
113-
# Folder to save results to
114-
traces_dir = "/home/cpuhrsch/tmp/traces/20230924"
115112

116113
# Create kernel traces
117114
profile_path = f"{traces_dir}/{args[0]}.json.gz"
118115
kwargs['profile_path'] = profile_path
119-
run_experiment(*args, **kwargs)
116+
rexp(*args, **kwargs)
120117
kwargs['profile_path'] = None
121118

122119
# Don't print header again if already printed
@@ -129,41 +126,65 @@ def run_traces(*args, **kwargs):
129126

130127
memory_path = f"{traces_dir}/{args[0]}"
131128
kwargs['memory_path'] = memory_path + ".pickle"
132-
run_experiment(*args, **kwargs)
129+
rexp(*args, **kwargs)
133130
kwargs['memory_path'] = None
134131

135132
# Convert memory trace to html page
136-
conversion_cmd = ["python", "/home/cpuhrsch/dev/pytorch/torch/cuda/_memory_viz.py",
133+
conversion_cmd = ["python", f"{pytorch_path}/torch/cuda/_memory_viz.py",
137134
"trace_plot", memory_path + ".pickle", "-o", memory_path + ".html"]
138135
result = subprocess.run(conversion_cmd, capture_output=True)
139136
assert result.returncode == 0
140137

141-
def run(experiments_data=None):
138+
def run(batch_size,
139+
model,
140+
experiments_data=None,
141+
run_traces=False,
142+
run_experiments=False,
143+
traces_dir=None,
144+
num_workers=32,
145+
print_header=True):
146+
147+
pytorch_path = "/home/cpuhrsch/dev/pytorch"
148+
sam_path = "/home/cpuhrsch/dev/segment-anything"
149+
assert model == "vit_b" or model == "vit_h"
150+
142151
if experiments_data is None:
143152
experiments_data = "experiments_data"
144153

145-
# run_traces("fp32", "default", "vit_b", 16, 32, print_header=True)
146-
# run_traces("fp16", "codesign", "vit_b", 16, 32, use_half=True)
147-
# run_traces("compile", "codesign", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
148-
# run_traces("SDPA", "sdpa-decoder", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
149-
# run_traces("Triton", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
150-
# run_traces("NT", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True)
151-
# run_traces("int8", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
152-
# run_traces("sparse", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
154+
rexp = functools.partial(run_experiment,
155+
experiments_data,
156+
sam_path,
157+
model,
158+
batch_size=batch_size,
159+
num_workers=num_workers)
153160

154-
rexp = functools.partial(run_experiment, experiments_data)
155161
# Keep the caller-supplied print_header; resetting it here would make the parameter a no-op.
156-
for bs, model in itertools.product([1, 32], ["vit_b", "vit_h"]):
157-
# rexp("fp32", "default", model, bs, 32, print_header=print_header)
162+
if run_traces:
163+
assert traces_dir is not None
164+
rt = functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp)
165+
166+
rt("fp32", "default", capture_output=False)
167+
rt("fp16", "codesign", use_half="bfloat16")
168+
rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
169+
rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
170+
rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
171+
if batch_size > 1:
172+
rt("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
173+
rt("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
174+
rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
175+
176+
if run_experiments:
177+
rexp("fp32", "default", print_header=print_header, capture_output=False)
158178
print_header = False
159-
# rexp("bf16", "codesign", model, bs, 32, use_half="bfloat16")
160-
# rexp("compile", "codesign", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
161-
# rexp("SDPA", "sdpa-decoder", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
162-
rexp("Triton", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", capture_output=False)
163-
if bs > 1:
164-
rexp("NT", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1))
165-
rexp("int8", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="dynamic_quant")
166-
rexp("sparse", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="sparse")
179+
rexp("bf16", "codesign", use_half="bfloat16")
180+
rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
181+
rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
182+
rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
183+
if batch_size > 1:
184+
rexp("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1))
185+
rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="dynamic_quant")
186+
rexp("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="sparse")
187+
167188

168189
if __name__ == '__main__':
169190
fire.Fire(run)

0 commit comments

Comments
 (0)