Skip to content

Commit 194b831

Browse files
committed
BUILD file fixes and linter run
1 parent b99fc1d commit 194b831

File tree

2 files changed

+170
-147
lines changed

2 files changed

+170
-147
lines changed

build/rocm/run_single_gpu.py

Lines changed: 170 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -25,179 +25,205 @@
2525
# Exit status recorded by run_test for a failing test module when the run is
# not continuing on failure; main() passes it to exit(). It is read and
# written under GPU_LOCK in run_test.
LAST_CODE = 0
2626
# Directory that holds the per-module pytest logs (json/csv/html) and the
# final compiled reports built from them.
base_dir = "./logs"
2727

28+
2829
def extract_filename(path):
    """Return the last component of ``path`` with its final extension removed.

    Example: ``"tests/foo_test.py"`` -> ``"foo_test"``.
    """
    base_name = os.path.basename(path)
    file_name, _ = os.path.splitext(base_name)
    return file_name
3233

3334

3435
def combine_json_reports(report_dir=None):
    """Merge every ``*_log.json`` report into ``final_compiled_report.json``.

    Each input file is parsed as JSON and appended to a list; the list is
    written to ``<report_dir>/final_compiled_report.json``.

    Args:
        report_dir: Directory to read the per-module reports from and write
            the merged report to. Defaults to the module-level ``base_dir``.
    """
    report_dir = base_dir if report_dir is None else report_dir
    # os.listdir() returns entries in arbitrary order; sort so the merged
    # report is reproducible between runs.
    json_files = sorted(
        f for f in os.listdir(report_dir) if f.endswith("_log.json")
    )
    combined_data = []
    for json_file in json_files:
        with open(os.path.join(report_dir, json_file), "r") as infile:
            combined_data.append(json.load(infile))
    combined_json_file = f"{report_dir}/final_compiled_report.json"
    with open(combined_json_file, "w") as outfile:
        json.dump(combined_data, outfile, indent=4)
4445

4546

4647
def combine_csv_reports(report_dir=None):
    """Concatenate every ``*_log.csv`` report into ``final_compiled_report.csv``.

    The header row is written exactly once (taken from the first non-empty
    input); the header row of every later file is dropped instead of being
    copied into the output as a data row. Empty input files are skipped.

    Args:
        report_dir: Directory to read the per-module reports from and write
            the merged report to. Defaults to the module-level ``base_dir``.
    """
    report_dir = base_dir if report_dir is None else report_dir
    # Sorted for a reproducible row order; os.listdir() order is arbitrary.
    csv_files = sorted(
        f for f in os.listdir(report_dir) if f.endswith("_log.csv")
    )
    combined_csv_file = f"{report_dir}/final_compiled_report.csv"
    header_written = False
    with open(combined_csv_file, mode="w", newline="") as outfile:
        csv_writer = csv.writer(outfile)
        for csv_file in csv_files:
            # newline="" lets the csv module handle embedded newlines itself.
            with open(
                os.path.join(report_dir, csv_file), mode="r", newline=""
            ) as infile:
                csv_reader = csv.reader(infile)
                # Always consume the header so it never leaks into the data.
                header = next(csv_reader, None)
                if header is None:
                    continue  # empty file: nothing to merge
                if not header_written:
                    csv_writer.writerow(header)
                    header_written = True
                csv_writer.writerows(csv_reader)
5960

6061

6162
def generate_final_report(shell=False, env_vars=None):
    """Build the final HTML, JSON and CSV reports from the per-module logs.

    Runs ``pytest_html_merger`` over ``base_dir`` to produce
    ``final_compiled_report.html``, then merges the JSON and CSV logs. A
    non-zero exit status from the merger is reported on stdout but does not
    stop the JSON/CSV merge.

    Args:
        shell: Forwarded to ``subprocess.run``.
        env_vars: Optional mapping of extra environment variables layered on
            top of the current process environment.
    """
    # A None default avoids sharing one mutable dict between calls.
    env = {**os.environ, **(env_vars or {})}
    cmd = [
        "pytest_html_merger",
        "-i",
        f"{base_dir}",
        "-o",
        f"{base_dir}/final_compiled_report.html",
    ]
    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
    if result.returncode != 0:
        print("FAILED - {}".format(" ".join(cmd)))
        print(result.stderr.decode())

    # Generate json reports.
    combine_json_reports()
    # Generate csv reports.
    combine_csv_reports()
7781

7882

7983
def run_shell_command(cmd, shell=False, env_vars=None):
    """Run ``cmd`` as a subprocess and capture its output.

    On a non-zero exit status the command line and its stderr are printed.

    Args:
        cmd: Command to run, as an argument list (or a string).
        shell: Forwarded to ``subprocess.run``.
        env_vars: Optional mapping of extra environment variables layered on
            top of the current process environment.

    Returns:
        A ``(returncode, stderr, stdout)`` tuple; both streams are decoded to
        ``str``.
    """
    # A None default avoids sharing one mutable dict between calls.
    env = {**os.environ, **(env_vars or {})}
    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
    # Decode once; errors="replace" keeps undecodable bytes in test output
    # from raising here and hiding the real result.
    stderr = result.stderr.decode(errors="replace")
    stdout = result.stdout.decode(errors="replace")
    if result.returncode != 0:
        # Joining a plain string would put a space between every character.
        printable = cmd if isinstance(cmd, str) else " ".join(cmd)
        print("FAILED - {}".format(printable))
        print(stderr)

    return result.returncode, stderr, stdout
9192

9293

9394
def parse_test_log(log_file):
    """Parse a JSON-lines report log and return the test module files in it.

    Every non-blank line is parsed as a JSON object. For objects carrying a
    ``"nodeid"``, the part before the first ``"::"`` is taken as the module
    path and kept when it contains ``".py"``.

    Args:
        log_file: Path to the JSON-lines log file.

    Returns:
        A set of absolute paths of the test module files.
    """
    test_files = set()
    with open(log_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # json.loads("") raises; tolerate blank/trailing lines.
                continue
            report = json.loads(line)
            if "nodeid" in report:
                module = report["nodeid"].split("::")[0]
                if module and ".py" in module:
                    test_files.add(os.path.abspath(module))
    return test_files
104105

105106

106107
def collect_testmodules():
    """Discover the test modules under ``tests`` via ``pytest --collect-only``.

    The collection is written to ``<base_dir>/collect_module_log.jsonl`` and
    parsed with ``parse_test_log``. If collection fails, its output is
    printed and the process exits with pytest's return code.

    Returns:
        The set of test module paths produced by ``parse_test_log``.
    """
    # Local import: sys.exit does not depend on the `site` module the way
    # the interactive `exit` helper does.
    import sys

    log_file = f"{base_dir}/collect_module_log.jsonl"
    return_code, stderr, stdout = run_shell_command(
        [
            "python3",
            "-m",
            "pytest",
            "--collect-only",
            "tests",
            f"--report-log={log_file}",
        ]
    )
    if return_code != 0:
        print("Test module discovery failed.")
        print("STDOUT:", stdout)
        print("STDERR:", stderr)
        sys.exit(return_code)
    print("---------- collected test modules ----------")
    test_files = parse_test_log(log_file)
    print("Found %d test modules." % (len(test_files)))
    print("--------------------------------------------")
    # Sets have no stable order; sort only the printed listing.
    print("\n".join(sorted(test_files)))
    return test_files
121130

122131

123132
def run_test(testmodule, gpu_tokens, continue_on_fail):
    """Run one test module with pytest on a GPU taken from ``gpu_tokens``.

    Does nothing if an earlier module has already recorded a failure in
    ``LAST_CODE``. Otherwise a GPU id is popped from ``gpu_tokens``, exposed
    to the child process through ``HIP_VISIBLE_DEVICES``, and handed back
    once the run is over. The run's output is printed, and when
    ``continue_on_fail`` is false the return code is stored in ``LAST_CODE``.

    Args:
        testmodule: Path of the test module to run.
        gpu_tokens: Shared list of free GPU ids; accessed under ``GPU_LOCK``.
        continue_on_fail: When false, pytest is run with ``-x`` and the
            return code is recorded in ``LAST_CODE``.
    """
    global LAST_CODE
    with GPU_LOCK:
        if LAST_CODE != 0:
            return
        target_gpu = gpu_tokens.pop()
    env_vars = {
        "HIP_VISIBLE_DEVICES": str(target_gpu),
        "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
    }
    testfile = extract_filename(testmodule)
    cmd = [
        "python3",
        "-m",
        "pytest",
        "--json-report",
        f"--json-report-file={base_dir}/{testfile}_log.json",
        f"--csv={base_dir}/{testfile}_log.csv",
        "--csv-columns",
        "id,module,name,file,status,duration",
        f"--html={base_dir}/{testfile}_log.html",
        "--reruns",
        "3",
    ]
    if not continue_on_fail:
        # Stop this module at its first failing test.
        cmd.append("-x")
    cmd += ["-v", testmodule]

    try:
        return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
    finally:
        # Always hand the GPU back, even if launching the subprocess raised;
        # otherwise the pool shrinks and later pop() calls fail.
        with GPU_LOCK:
            gpu_tokens.append(target_gpu)
    with GPU_LOCK:
        # Only report (and record) while no failure has been recorded yet,
        # so the first failing return code is the one that sticks.
        if LAST_CODE == 0:
            print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
            print(stdout)
            print(stderr)
            if not continue_on_fail:
                LAST_CODE = return_code
158186

159187

160188
def run_parallel(all_testmodules, p, c):
    """Run every test module through ``run_test`` on a pool of ``p`` threads.

    Args:
        all_testmodules: Iterable of test module paths.
        p: Number of worker threads; also the number of GPU tokens
            (``0 .. p-1``) shared by the workers.
        c: ``continue_on_fail`` flag forwarded to ``run_test``.
    """
    print(f"Running tests with parallelism = {p}")
    available_gpu_tokens = list(range(p))
    # Leaving the with-block waits for all submitted modules to finish.
    with ThreadPoolExecutor(max_workers=p) as executor:
        futures = {
            executor.submit(run_test, testmodule, available_gpu_tokens, c): testmodule
            for testmodule in all_testmodules
        }
    # An exception raised inside a worker is stored on its future and would
    # otherwise vanish silently; surface it without aborting the run.
    for future, testmodule in futures.items():
        error = future.exception()
        if error is not None:
            print(f"Test runner raised for module {testmodule}: {error!r}")
169197

170198

171199
def find_num_gpus():
    """Return the number of AMD GPUs reported by ``lspci``.

    Counts the ``lspci`` lines that match ``controller`` or ``accel`` and
    also contain ``AMD/ATI``.
    """
    # A one-element list with shell=True: the element is the shell command.
    cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
    _, _, stdout = run_shell_command(cmd, shell=True)
    # Treat missing output as "no GPUs" instead of raising on int("").
    return int(stdout.strip() or "0")
175203

176204

177205
def main(args):
    """Collect the test modules, run them in parallel and build the reports.

    Exits the process with ``LAST_CODE``.

    Args:
        args: Parsed command-line namespace providing ``parallel`` and
            ``continue_on_fail``.
    """
    # Local import: sys.exit does not depend on the `site` module the way
    # the interactive `exit` helper does.
    import sys

    all_testmodules = collect_testmodules()
    run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
    generate_final_report()
    sys.exit(LAST_CODE)


if __name__ == "__main__":
    os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p", "--parallel", type=int, help="number of tests to run in parallel"
    )
    parser.add_argument(
        "-c", "--continue_on_fail", action="store_true", help="continue on failure"
    )
    args = parser.parse_args()
    if args.continue_on_fail:
        print("continue on fail is set")
    if args.parallel is None:
        # Default to one worker per detected GPU.
        sys_gpu_count = find_num_gpus()
        args.parallel = sys_gpu_count
        print("%d GPUs detected." % sys_gpu_count)
    if args.parallel < 1:
        # ThreadPoolExecutor rejects max_workers <= 0 with a bare ValueError;
        # fail with a clear message instead.
        parser.error("--parallel must be at least 1 (were any GPUs detected?)")

    main(args)

jaxlib/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,6 @@ pybind_extension(
251251
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
252252
"@xla//xla/python:py_client_gpu",
253253
"@xla//xla/tsl/python/lib/core:numpy",
254-
"@com_google_absl//absl/status",
255-
"@local_config_rocm//rocm:rocm_headers",
256-
"@nanobind",
257254
],
258255
)
259256

0 commit comments

Comments
 (0)