Skip to content

Commit 67d5c89

Browse files
authored
Parameterize paths for experiments (#48)
1 parent c4676de commit 67d5c89

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

experiments/run_experiments.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,24 +133,20 @@ def run_traces_fn(traces_dir, pytorch_path, rexp, *args, **kwargs):
133133
conversion_cmd = ["python", f"{pytorch_path}/torch/cuda/_memory_viz.py",
134134
"trace_plot", memory_path + ".pickle", "-o", memory_path + ".html"]
135135
result = subprocess.run(conversion_cmd, capture_output=True)
136-
assert result.returncode == 0
137136

138137
def run(batch_size,
139138
model,
140-
experiments_data=None,
139+
pytorch_path,
140+
sam_path,
141+
experiments_data,
141142
run_traces=False,
142143
run_experiments=False,
143144
traces_dir=None,
144145
num_workers=32,
145146
print_header=True):
146147

147-
pytorch_path = "/home/cpuhrsch/dev/pytorch"
148-
sam_path = "/home/cpuhrsch/dev/segment-anything"
149148
assert model == "vit_b" or model == "vit_h"
150149

151-
if experiments_data is None:
152-
experiments_data = "experiments_data"
153-
154150
rexp = functools.partial(run_experiment,
155151
experiments_data,
156152
sam_path,
@@ -163,7 +159,7 @@ def run(batch_size,
163159
assert traces_dir is not None
164160
rt = functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp)
165161

166-
rt("fp32", "default", capture_output=False)
162+
rt("fp32", "default", print_header=print_header)
167163
rt("fp16", "codesign", use_half="bfloat16")
168164
rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
169165
rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
@@ -174,7 +170,7 @@ def run(batch_size,
174170
rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
175171

176172
if run_experiments:
177-
rexp("fp32", "default", print_header=print_header, capture_output=False)
173+
rexp("fp32", "default", print_header=print_header)
178174
print_header = False
179175
rexp("bf16", "codesign", use_half="bfloat16")
180176
rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")

0 commit comments

Comments
 (0)