From b84d3150687c02787075ec313d8d60a507f4cc64 Mon Sep 17 00:00:00 2001 From: FindHao Date: Sun, 20 Jul 2025 17:14:50 -0400 Subject: [PATCH] Refactor trace content parsing and enhance source mapping extraction Summary: - Updated `parse_single_trace_content` to only process compilation events, improving efficiency by skipping non-compilation events. - Enhanced the handling of IR file mappings and bidirectional mappings between different IR types. - Improved the organization of output lines for final writing, ensuring accurate grouping of compilation and launch events by kernel hash. - Added new utility functions for flattening and unflattening dictionaries, and for generating launch diffs, enhancing the overall functionality of the module. These changes aim to streamline the parsing process and improve the clarity and maintainability of the codebase. --- tritonparse/extract_source_mappings.py | 474 ++++++++++++++++++------- tritonparse/structured_logging.py | 7 +- 2 files changed, 350 insertions(+), 131 deletions(-) diff --git a/tritonparse/extract_source_mappings.py b/tritonparse/extract_source_mappings.py index 194f4c9..a3fd972 100644 --- a/tritonparse/extract_source_mappings.py +++ b/tritonparse/extract_source_mappings.py @@ -449,80 +449,80 @@ def parse_single_trace_content(trace_content: str) -> str: """ entry = json.loads(trace_content) - if entry.get("event_type") != "compilation" and "payload" in entry: - logger.warning("Not a compilation event. Skipping.") - return "" - payload = entry.setdefault("payload", {}) - file_content = payload.get("file_content", {}) - file_path = payload.get("file_path", {}) - - # Find the IR file keys - ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) - ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) - ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) - amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) - # Skip if no IR files found - if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): - logger.warning("No IR files found in the payload.") - return trace_content - - # generate ttir->source, ttgir->source, ptx->source - ttir_map = process_ir(ttir_key, file_content, file_path) - ttgir_map = process_ir(ttgir_key, file_content, file_path) - ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) - amdgcn_map = process_ir(amdgcn_key, file_content, file_path, [ttir_map, ttgir_map]) - - # Create bidirectional mappings between all IR types - ir_maps = { - "ttir": ttir_map, - "ttgir": ttgir_map, - "ptx": ptx_map, - "amdgcn": amdgcn_map, - } - - # Create mappings between all pairs of IR types - ir_types = list(ir_maps.keys()) - for i, src_type in enumerate(ir_types): - for tgt_type in ir_types[i + 1 :]: - if ir_maps[src_type] and ir_maps[tgt_type]: - create_bidirectional_mapping( - ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type - ) - logger.debug( - f"Created bidirectional mapping between {src_type} and {tgt_type}" - ) - - py_map = {} - - if "python_source" in payload: - logger.debug( - f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + if entry.get("event_type") == "compilation": + payload = entry.setdefault("payload", {}) + file_content = payload.get("file_content", {}) + file_path = payload.get("file_path", {}) + + # Find the IR file keys + ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) + ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) + ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) + amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) + # Skip if no IR files found + if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): + logger.warning("No IR files found in the payload.") + return trace_content + + # generate ttir->source, ttgir->source, ptx->source + ttir_map = process_ir(ttir_key, file_content, file_path) + ttgir_map = process_ir(ttgir_key, file_content, file_path) + ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) + amdgcn_map = process_ir( + amdgcn_key, file_content, file_path, [ttir_map, ttgir_map] ) - # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. - # Create a list of valid IR mappings, filtering out None keys - ir_mappings = [] - ir_keys_and_maps = [ - (ttir_key, ttir_map), - (ttgir_key, ttgir_map), - (ptx_key, ptx_map), - (amdgcn_key, amdgcn_map), - ] - - for key, mapping in ir_keys_and_maps: - if key: - ir_mappings.append((get_file_extension(key), mapping)) - - py_map = create_python_mapping(ir_mappings) - - # Store the mappings in the payload - payload["source_mappings"] = { - "ttir": ttir_map, - "ttgir": ttgir_map, - **({"ptx": ptx_map} if ptx_map else {}), - **({"amdgcn": amdgcn_map} if amdgcn_map else {}), - "python": py_map, - } + # Create bidirectional mappings between all IR types + ir_maps = { + "ttir": ttir_map, + "ttgir": ttgir_map, + "ptx": ptx_map, + "amdgcn": amdgcn_map, + } + + # Create mappings between all pairs of IR types + ir_types = list(ir_maps.keys()) + for i, src_type in enumerate(ir_types): + for tgt_type in ir_types[i + 1 :]: + if ir_maps[src_type] and ir_maps[tgt_type]: + create_bidirectional_mapping( + ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type + ) + logger.debug( + f"Created bidirectional mapping between {src_type} and {tgt_type}" + ) + + py_map = {} + + if "python_source" in payload: + logger.debug( + f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + ) + + # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. + # Create a list of valid IR mappings, filtering out None keys + ir_mappings = [] + ir_keys_and_maps = [ + (ttir_key, ttir_map), + (ttgir_key, ttgir_map), + (ptx_key, ptx_map), + (amdgcn_key, amdgcn_map), + ] + + for key, mapping in ir_keys_and_maps: + if key: + ir_mappings.append((get_file_extension(key), mapping)) + + py_map = create_python_mapping(ir_mappings) + + # Store the mappings in the payload + payload["source_mappings"] = { + "ttir": ttir_map, + "ttgir": ttgir_map, + **({"ptx": ptx_map} if ptx_map else {}), + **({"amdgcn": amdgcn_map} if amdgcn_map else {}), + "python": py_map, + } # NDJSON format requires a newline at the end of each line return json.dumps(entry, separators=(",", ":")) + "\n" @@ -533,90 +533,134 @@ def parse_single_file( split_by_frame_id_and_compile_id: bool = True, ): """ - Process a single file and extract source code mappings. + Process a single file, correctly group events by kernel, and extract mappings. - This function takes a file path as input, reads the file content, and extracts - source code mappings from Triton trace JSON files. It processes each line of the file, - parses the trace content to extract IR mappings, and writes the updated content - to output files. + This function reads a trace file, groups compilation and launch events by + their kernel hash, generates a launch_diff event for each kernel, and writes + the processed data to output files. Args: file_path (str): The path to the file to be processed. - output_dir (str, optional): Directory to save the output files with mappings. - split_by_frame_id_and_compile_id (bool, optional): Whether to split output files - by frame_id and compile_id. Defaults to True. - - Returns: - None. The function writes the processed data to files in the output_dir. - Each output file will contain the original trace data enriched with source mappings - in NDJSON format (one JSON object per line). + output_dir (str, optional): Directory to save the output files. + split_by_frame_id_and_compile_id (bool, optional): Whether to split + output files by frame_id and compile_id. Defaults to True. """ - outputs = defaultdict(list) + kernels_by_hash = defaultdict( + lambda: {"compilation": None, "launches": [], "output_file": None} + ) - # Set default output directory if not provided output_dir = output_dir or os.path.dirname(file_path) - - # Check if input file is compressed based on file extension is_compressed_input = file_path.endswith(".bin.ndjson") - - # Open file in appropriate mode - use gzip.open for compressed files - if is_compressed_input: - # Use gzip.open which automatically handles member concatenation - file_handle = gzip.open(file_path, "rt", encoding="utf-8") - else: - file_handle = open(file_path, "r") + file_handle = ( + gzip.open(file_path, "rt", encoding="utf-8") + if is_compressed_input + else open(file_path, "r") + ) with file_handle as f: file_name = os.path.basename(file_path) - # Handle .bin.ndjson extension properly - if is_compressed_input: - file_name_without_extension = file_name[:-11] # Remove .bin.ndjson - else: - file_name_without_extension = os.path.splitext(file_name)[0] + file_name_without_extension = ( + file_name[:-11] if is_compressed_input else os.path.splitext(file_name)[0] + ) - # Process lines uniformly for both compressed and uncompressed files for i, line in enumerate(f): logger.debug(f"Processing line {i + 1} in {file_path}") - json_str = line.strip() if not json_str: continue - parsed_line = parse_single_trace_content(json_str) - if not parsed_line: - logger.warning(f"Failed to parse line {i + 1} in {file_path}") + # We don't need to generate full mappings for every line here, + # just enough to get the event type and necessary IDs. + try: + parsed_json = json.loads(json_str) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON on line {i + 1} in {file_path}") continue - parsed_json = json.loads(parsed_line) - payload = parsed_json.get("payload", None) - if split_by_frame_id_and_compile_id: - if not payload: - logger.warning("No payload found in the parsed JSON.") + event_type = parsed_json.get("event_type", None) + payload = parsed_json.get("payload", {}) + + if event_type == "compilation": + kernel_hash = payload.get("metadata", {}).get("hash") + if not kernel_hash: continue - pt_info = payload.get("pt_info", {}) - frame_id = pt_info.get("frame_id", None) - frame_compile_id = pt_info.get("frame_compile_id", None) - compiled_autograd_id = pt_info.get("compiled_autograd_id", "-") - attempt_id = pt_info.get("attempt_id", 0) - output_file_name = "" - if frame_id is not None or frame_compile_id is not None: - output_file_name = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{compiled_autograd_id}.ndjson" + + if split_by_frame_id_and_compile_id: + pt_info = payload.get("pt_info", {}) + frame_id = pt_info.get("frame_id") + frame_compile_id = pt_info.get("frame_compile_id") + attempt_id = pt_info.get("attempt_id", 0) + cai = pt_info.get("compiled_autograd_id", "-") + if frame_id is not None or frame_compile_id is not None: + fname = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson" + else: + fname = f"{file_name_without_extension}_mapped.ndjson" else: - logger.debug( - "No frame_id or frame_compile_id found in the payload." + fname = f"{file_name_without_extension}_mapped.ndjson" + + output_file = os.path.join(output_dir, fname) + # The full processing is deferred until the final write. + kernels_by_hash[kernel_hash]["compilation"] = json_str + kernels_by_hash[kernel_hash]["output_file"] = output_file + + elif event_type == "launch": + kernel_hash = parsed_json.get("compilation_metadata", {}).get("hash") + if kernel_hash: + kernels_by_hash[kernel_hash]["launches"].append( + (parsed_json, i + 1) ) - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - else: - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - output_file = os.path.join(output_dir, output_file_name) - outputs[output_file].append(parsed_line) - logger.debug(f"output file: {output_file}") + + # Organize lines for final output, keyed by output file path + all_output_lines = defaultdict(list) + for _kernel_hash, data in kernels_by_hash.items(): + compilation_json_str = data["compilation"] + launches_with_indices = data["launches"] + output_file = data["output_file"] + + if not output_file: + logger.warning(f"No output file for kernel hash {_kernel_hash}, skipping.") + continue + + # Process the compilation event now to include source mappings + if compilation_json_str: + processed_compilation_line = parse_single_trace_content( + compilation_json_str + ) + all_output_lines[output_file].append(processed_compilation_line) + compilation_event = json.loads(processed_compilation_line) + else: + compilation_event = None + + for launch_event, _ in launches_with_indices: + all_output_lines[output_file].append( + json.dumps(launch_event, separators=(",", ":")) + "\n" + ) + + if compilation_event and launches_with_indices: + sames, diffs, launch_index_map = _generate_launch_diff( + launches_with_indices + ) + launch_diff_event = { + "event_type": "launch_diff", + "hash": _kernel_hash, + "name": compilation_event.get("payload", {}) + .get("metadata", {}) + .get("name"), + "total_launches": len(launches_with_indices), + "launch_index_map": launch_index_map, + "diffs": diffs, + "sames": sames, + } + all_output_lines[output_file].append( + json.dumps(launch_diff_event, separators=(",", ":")) + "\n" + ) if not os.path.exists(output_dir): os.makedirs(output_dir) - for output_file, parsed_lines in outputs.items(): + + for output_file, final_lines in all_output_lines.items(): with open(output_file, "w") as out: - out.writelines(parsed_lines) + out.writelines(final_lines) def parse_args(): @@ -638,6 +682,178 @@ def parse_args(): return parser.parse_args() +# Fields that are expected to vary but are not useful to list out in the diff. +SUMMARY_FIELDS = ["pid", "timestamp", "stream", "function", "data_ptr"] + + +def _flatten_dict( + d: Dict[str, Any], parent_key: str = "", sep: str = "." +) -> Dict[str, Any]: + """ + Flattens a nested dictionary. + """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def _unflatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: + """ + Unflattens a dictionary with delimited keys. + """ + result = {} + for key, value in d.items(): + parts = key.split(sep) + d_ref = result + for part in parts[:-1]: + if part not in d_ref: + d_ref[part] = {} + d_ref = d_ref[part] + d_ref[parts[-1]] = value + return result + + +def _to_ranges(indices: List[int]) -> List[Dict[str, int]]: + """ + Converts a sorted list of indices into a list of continuous ranges. + e.g., [0, 1, 2, 5, 6, 8] -> [{'start': 0, 'end': 2}, {'start': 5, 'end': 6}, {'start': 8, 'end': 8}] + """ + if not indices: + return [] + + indices = sorted(indices) + ranges = [] + start = indices[0] + end = indices[0] + + for i in range(1, len(indices)): + if indices[i] == end + 1: + end = indices[i] + else: + ranges.append({"start": start, "end": end}) + start = end = indices[i] + + ranges.append({"start": start, "end": end}) + return ranges + + +def _generate_launch_diff( + launches: List[Tuple[Dict[str, Any], int]], +) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, int]]]: + """ + Compares a list of launch events and returns sames, diffs, and an index map. + """ + if not launches: + return {}, {}, [] + + launch_events = [launch[0] for launch in launches] + launch_index_map = [launch[1] for launch in launches] + + if len(launch_events) == 1: + return ( + _unflatten_dict(_flatten_dict(launch_events[0])), + {}, + _to_ranges(launch_index_map), + ) + + # Group values by key + data_by_key = defaultdict(lambda: defaultdict(list)) + for i, launch in enumerate(launch_events): + launch_flat = _flatten_dict(launch) + for key, value in launch_flat.items(): + # JSON doesn't support all Python types as values directly, str is safer + value_str = json.dumps(value, sort_keys=True) + data_by_key[key][value_str].append(i) + + sames_flat = {} + diffs_flat = {} + + for key, value_groups in data_by_key.items(): + if len(value_groups) == 1: + # This key has the same value across all launches + value_str = list(value_groups.keys())[0] + sames_flat[key] = json.loads(value_str) + else: + # This key has different values + is_summary = any(summary_key in key for summary_key in SUMMARY_FIELDS) + if is_summary: + diffs_flat[key] = { + "diff_type": "summary", + "summary_text": f"Varies across {len(value_groups)} unique values", + } + else: + values_dist = [] + for value_str, indices in value_groups.items(): + values_dist.append( + { + "value": json.loads(value_str), + "count": len(indices), + "launches": _to_ranges(indices), + } + ) + # Sort by first occurrence + values_dist.sort(key=lambda x: x["launches"][0]["start"]) + diffs_flat[key] = { + "diff_type": "distribution", + "values": values_dist, + } + + # Unflatten the results + sames_unflattened = _unflatten_dict(sames_flat) + diffs_unflattened = _unflatten_dict(diffs_flat) + + # Special handling for extracted_args to create argument_diff structures + if "extracted_args" in sames_unflattened or "extracted_args" in diffs_unflattened: + sames_args = sames_unflattened.pop("extracted_args", {}) + diffs_args_flat = diffs_unflattened.pop("extracted_args", {}) + + all_arg_names = set(sames_args.keys()) | set(diffs_args_flat.keys()) + + final_arg_diffs = {} + + for arg_name in all_arg_names: + if arg_name in diffs_args_flat: + # This argument has at least one differing sub-field. + arg_sames = {} + arg_diffs_internal = {} + + # Collect all sub-fields for this argument from the original data + all_sub_fields = set() + for launch in launch_events: + arg_data = launch.get("extracted_args", {}).get(arg_name, {}) + all_sub_fields.update(arg_data.keys()) + + for sub_field in all_sub_fields: + flat_key = f"extracted_args.{arg_name}.{sub_field}" + if flat_key in diffs_flat: + arg_diffs_internal[sub_field] = diffs_flat[flat_key] + elif flat_key in sames_flat: + arg_sames[sub_field] = sames_flat[flat_key] + + if arg_sames or arg_diffs_internal: + final_arg_diffs[arg_name] = { + "diff_type": "argument_diff", + "sames": arg_sames, + "diffs": arg_diffs_internal, + } + elif arg_name in sames_args: + # This argument is entirely the same across all launches. + # We move it back to the main sames dict for consistency. + if "extracted_args" not in sames_unflattened: + sames_unflattened["extracted_args"] = {} + sames_unflattened["extracted_args"][arg_name] = sames_args[arg_name] + + if final_arg_diffs: + diffs_unflattened["extracted_args"] = final_arg_diffs + + return sames_unflattened, diffs_unflattened, _to_ranges(launch_index_map) + + if __name__ == "__main__": args = parse_args() if args.input: diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py index 76c6b6d..81711f9 100644 --- a/tritonparse/structured_logging.py +++ b/tritonparse/structured_logging.py @@ -156,6 +156,9 @@ def convert(obj): # 2. simple containers ---------------------------------------------------- if isinstance(obj, (list, tuple)): + # Handle namedtuple specially to preserve field names + if hasattr(obj, "_asdict"): + return convert(obj._asdict()) return [convert(x) for x in obj] if isinstance(obj, (set, frozenset)): @@ -810,7 +813,7 @@ def extract_arg_info(arg_dict): def add_launch_metadata(grid, metadata, arg_dict): # Extract detailed argument information extracted_args = extract_arg_info(arg_dict) - return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)} + return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)} class JITHookImpl(JITHook): @@ -928,7 +931,7 @@ def __call__(self, metadata): ) if launch_metadata_tritonparse is not None: trace_data["grid"] = launch_metadata_tritonparse[0] - trace_data["metadata"] = launch_metadata_tritonparse[1] + trace_data["compilation_metadata"] = launch_metadata_tritonparse[1] trace_data["extracted_args"] = launch_metadata_tritonparse[ 2 ] # Now contains detailed arg info