
Commit b1f70ba

FindHao authored and facebook-github-bot committed
Refactor: Modularize extract_source_mappings.py for improved maintainability (#44)
Summary: The `tritonparse/tritonparse/extract_source_mappings.py` script was a large, monolithic file containing logic for IR parsing, source mapping, event diffing, and file processing. This made it difficult to read, maintain, and extend. This pull request refactors the script by breaking it down into smaller, single-responsibility modules. This improves the overall code structure, enhances readability, and makes future development easier.

### Description of Changes

The core logic of `extract_source_mappings.py` has been split into several new modules within the `tritonparse/tritonparse/` directory:

* **`extract_source_mappings.py` (Modified):** Now serves as a clean command-line entry point, containing only argument parsing and the main execution call.
* **`trace_processor.py` (New):** The main orchestrator that handles the processing of trace files. It contains the core logic previously found in `parse_single_file` and `parse_single_trace_content`.
* **`ir_parser.py` (New):** Contains all functions related to parsing Intermediate Representations (TTIR, TTGIR, PTX, AMDGCN), including `loc` directive extraction.
* **`mapper.py` (New):** Responsible for creating the bidirectional source mappings between Python source and various IRs.
* **`event_diff.py` (New):** Provides functionality to compare a list of events and generate a summary of their differences.
* **`sourcemap_utils.py` (New):** A collection of general-purpose utility functions (e.g., dictionary flattening) used by the new modules.

Pull Request resolved: #44

Test Plan:

```bash
% python -m unittest tests.test_tritonparse -v
test_convert (tests.test_tritonparse.TestTritonparseCPU.test_convert)
Test convert function with various data types ... ok
test_complex_kernels (tests.test_tritonparse.TestTritonparseCUDA.test_complex_kernels)
A more complex test case involving two distinct Triton kernels, one of which uses autotuning. ...
Temporary directory: /tmp/tmphy6i6cw8
--- Testing Matmul Kernel (3 launches) ---
WARNING:tritonparse.structured_logging:fn JitFunctionInfo(module='tests.test_tritonparse', name='TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel', jit_function=JITFunction(tests.test_tritonparse:TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel)) launch_metadata is not None: <function add_launch_metadata at 0x7a6893ca80e0>. It will be overridden by tritonparse.
WARNING:tritonparse.structured_logging:fn JitFunctionInfo(module='tests.test_tritonparse', name='TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel', jit_function=JITFunction(tests.test_tritonparse:TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel)) launch_metadata is not None: <function add_launch_metadata at 0x7a6893ca80e0>. It will be overridden by tritonparse.
Matmul Launch 1 (16x16 @ 16x16) done.
Matmul Launch 2 (32x16 @ 16x32) done.
Matmul Launch 3 (16x32 @ 32x16) done.
--- Testing Fused Op Kernel (4 launches) ---
Fused Op Launch 1: scale=1.0, activation=None
Fused Op Launch 2: scale=2.5, activation=None
Fused Op Launch 3: scale=1.0, activation='relu'
WARNING:tritonparse.structured_logging:fn JitFunctionInfo(module='tests.test_tritonparse', name='TestTritonparseCUDA.test_complex_kernels.<locals>.fused_op_kernel', jit_function=JITFunction(tests.test_tritonparse:TestTritonparseCUDA.test_complex_kernels.<locals>.fused_op_kernel)) launch_metadata is not None: <function add_launch_metadata at 0x7a6893ca80e0>. It will be overridden by tritonparse.
Fused Op Launch 4: scale=1.0, activation='relu', different size
All kernels executed.
tritonparse log file list: /tmp/tmpeo5py734/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmpeo5py734 to /tmp/tmphy6i6cw8/parsed_output_complex
================================================================================
📁 TRITONPARSE PARSING RESULTS
================================================================================
📂 Parsed files directory: /tmp/tmphy6i6cw8/parsed_output_complex
📊 Total files generated: 2
📄 Generated files:
--------------------------------------------------
  1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (170.3KB)
  2. 📝 log_file_list.json (181B)
================================================================================
✅ Parsing completed successfully!
================================================================================
✓ Generated 1 log files
✓ Generated 2 parsed files
✓ Found 1 .json files and 1 .ndjson.gz files
Checking launch_diff events in dedicated_log_triton_trace_findhao__mapped.ndjson.gz
Line 402: Found launch_diff event (count: 1)
Line 897: Found launch_diff event (count: 2)
Line 1384: Found launch_diff event (count: 3)
Line 1388: Found launch_diff event (count: 4)
Line 1392: Found launch_diff event (count: 5)
✓ Total launch_diff events found: 5
✓ Verified 5 launch_diff events in parsed output
✓ Cleaned up temporary directory
ok
test_extract_python_source_info (tests.test_tritonparse.TestTritonparseCUDA.test_extract_python_source_info)
Test extract_python_source_info function ... ok
test_whole_workflow (tests.test_tritonparse.TestTritonparseCUDA.test_whole_workflow)
Test unified_parse functionality ...
Temporary directory: /tmp/tmpik7n136f
Found 1 log files in /tmp/tmpik7n136f/logs: ['dedicated_log_triton_trace_findhao_.ndjson']
Line 1: event_type = 'compilation' (unique hash: 258d1b0a...)
Line 2: event_type = 'launch' (count: 1)
Line 3: event_type = 'launch' (count: 2)
Event type counts: {'launch': 2, 'compilation': 1} (unique compilation hashes: 1)
✓ Verified correct event type counts: 1 unique compilation hash, 2 launch events
tritonparse log file list: /tmp/tmp53z4ny7b/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmp53z4ny7b to /tmp/tmpik7n136f/parsed_output
================================================================================
📁 TRITONPARSE PARSING RESULTS
================================================================================
📂 Parsed files directory: /tmp/tmpik7n136f/parsed_output
📊 Total files generated: 2
📄 Generated files:
--------------------------------------------------
  1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.9KB)
  2. 📝 log_file_list.json (181B)
================================================================================
✅ Parsing completed successfully!
================================================================================
✓ Cleaned up temporary directory
ok
----------------------------------------------------------------------
Ran 4 tests in 16.816s

OK
```

Reviewed By: davidberard98

Differential Revision: D78982556

Pulled By: FindHao

fbshipit-source-id: ef2bca8f7c77460a48199dafb887fbe0ebf1f044
1 parent 2274615 commit b1f70ba

File tree

6 files changed

+845
-815
lines changed


tritonparse/event_diff.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import json
from collections import defaultdict
from typing import Any, Dict, List, Tuple

from .sourcemap_utils import _flatten_dict, _to_ranges, _unflatten_dict

# Fields that are expected to vary but are not useful to list out in the diff.
SUMMARY_FIELDS = ["pid", "timestamp", "stream", "function", "data_ptr"]


def _generate_launch_diff(
    launches: List[Tuple[Dict[str, Any], int]],
) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, int]]]:
    """
    Compares a list of launch events and returns sames, diffs, and an index map.
    """
    if not launches:
        return {}, {}, []

    launch_events = [launch[0] for launch in launches]
    launch_index_map = [launch[1] for launch in launches]

    if len(launch_events) == 1:
        return (
            _unflatten_dict(_flatten_dict(launch_events[0])),
            {},
            _to_ranges(launch_index_map),
        )

    # Group values by key
    data_by_key = defaultdict(lambda: defaultdict(list))
    for i, launch in enumerate(launch_events):
        launch_flat = _flatten_dict(launch)
        for key, value in launch_flat.items():
            # JSON doesn't support all Python types as values directly, str is safer
            value_str = json.dumps(value, sort_keys=True)
            data_by_key[key][value_str].append(i)

    sames_flat = {}
    diffs_flat = {}

    for key, value_groups in data_by_key.items():
        if len(value_groups) == 1:
            # This key has the same value across all launches
            value_str = list(value_groups.keys())[0]
            sames_flat[key] = json.loads(value_str)
        else:
            # This key has different values
            is_summary = any(summary_key in key for summary_key in SUMMARY_FIELDS)
            if is_summary:
                diffs_flat[key] = {
                    "diff_type": "summary",
                    "summary_text": f"Varies across {len(value_groups)} unique values",
                }
            else:
                values_dist = []
                for value_str, indices in value_groups.items():
                    values_dist.append(
                        {
                            "value": json.loads(value_str),
                            "count": len(indices),
                            "launches": _to_ranges(indices),
                        }
                    )
                # Sort by first occurrence
                values_dist.sort(key=lambda x: x["launches"][0]["start"])
                diffs_flat[key] = {
                    "diff_type": "distribution",
                    "values": values_dist,
                }

    # Unflatten the results
    sames_unflattened = _unflatten_dict(sames_flat)
    diffs_unflattened = _unflatten_dict(diffs_flat)

    # Special handling for extracted_args to create argument_diff structures
    if "extracted_args" in sames_unflattened or "extracted_args" in diffs_unflattened:
        sames_args = sames_unflattened.pop("extracted_args", {})
        diffs_args_flat = diffs_unflattened.pop("extracted_args", {})

        all_arg_names = set(sames_args.keys()) | set(diffs_args_flat.keys())

        final_arg_diffs = {}

        for arg_name in all_arg_names:
            if arg_name in diffs_args_flat:
                # This argument has at least one differing sub-field.
                arg_sames = {}
                arg_diffs_internal = {}

                # Collect all sub-fields for this argument from the original data
                all_sub_fields = set()
                for launch in launch_events:
                    arg_data = launch.get("extracted_args", {}).get(arg_name, {})
                    all_sub_fields.update(arg_data.keys())

                for sub_field in all_sub_fields:
                    flat_key = f"extracted_args.{arg_name}.{sub_field}"
                    if flat_key in diffs_flat:
                        arg_diffs_internal[sub_field] = diffs_flat[flat_key]
                    elif flat_key in sames_flat:
                        arg_sames[sub_field] = sames_flat[flat_key]

                if arg_sames or arg_diffs_internal:
                    final_arg_diffs[arg_name] = {
                        "diff_type": "argument_diff",
                        "sames": arg_sames,
                        "diffs": arg_diffs_internal,
                    }
            elif arg_name in sames_args:
                # This argument is entirely the same across all launches.
                # We move it back to the main sames dict for consistency.
                if "extracted_args" not in sames_unflattened:
                    sames_unflattened["extracted_args"] = {}
                sames_unflattened["extracted_args"][arg_name] = sames_args[arg_name]

        if final_arg_diffs:
            diffs_unflattened["extracted_args"] = final_arg_diffs

    return sames_unflattened, diffs_unflattened, _to_ranges(launch_index_map)
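To see the sames/diffs split at the heart of `_generate_launch_diff` in isolation, here is a self-contained sketch of the flatten-then-group step. The `_flatten_dict` stand-in and the sample launch events are hypothetical (the real helper lives in `sourcemap_utils.py`), but the grouping mirrors the loop above: values are flattened to dotted keys, serialized with `json.dumps` so unhashable values can be grouped, and a key lands in `sames` when every launch agrees on it.

```python
import json
from collections import defaultdict


# Hypothetical stand-in for sourcemap_utils._flatten_dict (real code may differ).
def _flatten_dict(d, parent="", sep="."):
    out = {}
    for k, v in d.items():
        key = f"{parent}{sep}{k}" if parent else k
        if isinstance(v, dict):
            out.update(_flatten_dict(v, key, sep))
        else:
            out[key] = v
    return out


# Two illustrative launch events: same grid, but one argument value differs.
launches = [
    {"grid": [64, 1, 1], "extracted_args": {"scale": {"value": 1.0}}},
    {"grid": [64, 1, 1], "extracted_args": {"scale": {"value": 2.5}}},
]

# Group serialized values by flattened key, as _generate_launch_diff does.
data_by_key = defaultdict(lambda: defaultdict(list))
for i, launch in enumerate(launches):
    for key, value in _flatten_dict(launch).items():
        data_by_key[key][json.dumps(value, sort_keys=True)].append(i)

sames, diffs = {}, {}
for key, groups in data_by_key.items():
    if len(groups) == 1:
        # One value group: identical across all launches.
        sames[key] = json.loads(next(iter(groups)))
    else:
        # Multiple value groups: record which launches saw which value.
        diffs[key] = dict(groups)

print(sames)  # {'grid': [64, 1, 1]}
print(diffs)  # {'extracted_args.scale.value': {'1.0': [0], '2.5': [1]}}
```

The real function then goes further: summary-only fields (`SUMMARY_FIELDS`) collapse to a one-line note, differing values become sorted distributions with launch ranges, and `extracted_args` entries are regrouped into per-argument `argument_diff` structures.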
