diff --git a/tritonparse/extract_source_mappings.py b/tritonparse/extract_source_mappings.py index 194f4c9..a3fd972 100644 --- a/tritonparse/extract_source_mappings.py +++ b/tritonparse/extract_source_mappings.py @@ -449,80 +449,80 @@ def parse_single_trace_content(trace_content: str) -> str: """ entry = json.loads(trace_content) - if entry.get("event_type") != "compilation" and "payload" in entry: - logger.warning("Not a compilation event. Skipping.") - return "" - payload = entry.setdefault("payload", {}) - file_content = payload.get("file_content", {}) - file_path = payload.get("file_path", {}) - - # Find the IR file keys - ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) - ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) - ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) - amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) - # Skip if no IR files found - if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): - logger.warning("No IR files found in the payload.") - return trace_content - - # generate ttir->source, ttgir->source, ptx->source - ttir_map = process_ir(ttir_key, file_content, file_path) - ttgir_map = process_ir(ttgir_key, file_content, file_path) - ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) - amdgcn_map = process_ir(amdgcn_key, file_content, file_path, [ttir_map, ttgir_map]) - - # Create bidirectional mappings between all IR types - ir_maps = { - "ttir": ttir_map, - "ttgir": ttgir_map, - "ptx": ptx_map, - "amdgcn": amdgcn_map, - } - - # Create mappings between all pairs of IR types - ir_types = list(ir_maps.keys()) - for i, src_type in enumerate(ir_types): - for tgt_type in ir_types[i + 1 :]: - if ir_maps[src_type] and ir_maps[tgt_type]: - create_bidirectional_mapping( - ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type - ) - logger.debug( - f"Created bidirectional mapping between {src_type} and {tgt_type}" - ) - - py_map = {} - - if "python_source" in payload: - logger.debug( - f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + if entry.get("event_type") == "compilation": + payload = entry.setdefault("payload", {}) + file_content = payload.get("file_content", {}) + file_path = payload.get("file_path", {}) + + # Find the IR file keys + ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) + ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) + ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) + amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) + # Skip if no IR files found + if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): + logger.warning("No IR files found in the payload.") + return trace_content + + # generate ttir->source, ttgir->source, ptx->source + ttir_map = process_ir(ttir_key, file_content, file_path) + ttgir_map = process_ir(ttgir_key, file_content, file_path) + ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) + amdgcn_map = process_ir( + amdgcn_key, file_content, file_path, [ttir_map, ttgir_map] ) - # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. - # Create a list of valid IR mappings, filtering out None keys - ir_mappings = [] - ir_keys_and_maps = [ - (ttir_key, ttir_map), - (ttgir_key, ttgir_map), - (ptx_key, ptx_map), - (amdgcn_key, amdgcn_map), - ] - - for key, mapping in ir_keys_and_maps: - if key: - ir_mappings.append((get_file_extension(key), mapping)) - - py_map = create_python_mapping(ir_mappings) - - # Store the mappings in the payload - payload["source_mappings"] = { - "ttir": ttir_map, - "ttgir": ttgir_map, - **({"ptx": ptx_map} if ptx_map else {}), - **({"amdgcn": amdgcn_map} if amdgcn_map else {}), - "python": py_map, - } + # Create bidirectional mappings between all IR types + ir_maps = { + "ttir": ttir_map, + "ttgir": ttgir_map, + "ptx": ptx_map, + "amdgcn": amdgcn_map, + } + + # Create mappings between all pairs of IR types + ir_types = list(ir_maps.keys()) + for i, src_type in enumerate(ir_types): + for tgt_type in ir_types[i + 1 :]: + if ir_maps[src_type] and ir_maps[tgt_type]: + create_bidirectional_mapping( + ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type + ) + logger.debug( + f"Created bidirectional mapping between {src_type} and {tgt_type}" + ) + + py_map = {} + + if "python_source" in payload: + logger.debug( + f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + ) + + # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. + # Create a list of valid IR mappings, filtering out None keys + ir_mappings = [] + ir_keys_and_maps = [ + (ttir_key, ttir_map), + (ttgir_key, ttgir_map), + (ptx_key, ptx_map), + (amdgcn_key, amdgcn_map), + ] + + for key, mapping in ir_keys_and_maps: + if key: + ir_mappings.append((get_file_extension(key), mapping)) + + py_map = create_python_mapping(ir_mappings) + + # Store the mappings in the payload + payload["source_mappings"] = { + "ttir": ttir_map, + "ttgir": ttgir_map, + **({"ptx": ptx_map} if ptx_map else {}), + **({"amdgcn": amdgcn_map} if amdgcn_map else {}), + "python": py_map, + } # NDJSON format requires a newline at the end of each line return json.dumps(entry, separators=(",", ":")) + "\n" @@ -533,90 +533,134 @@ def parse_single_file( split_by_frame_id_and_compile_id: bool = True, ): """ - Process a single file and extract source code mappings. + Process a single file, correctly group events by kernel, and extract mappings. - This function takes a file path as input, reads the file content, and extracts - source code mappings from Triton trace JSON files. It processes each line of the file, - parses the trace content to extract IR mappings, and writes the updated content - to output files. + This function reads a trace file, groups compilation and launch events by + their kernel hash, generates a launch_diff event for each kernel, and writes + the processed data to output files. Args: file_path (str): The path to the file to be processed. - output_dir (str, optional): Directory to save the output files with mappings. - split_by_frame_id_and_compile_id (bool, optional): Whether to split output files - by frame_id and compile_id. Defaults to True. - - Returns: - None. The function writes the processed data to files in the output_dir. - Each output file will contain the original trace data enriched with source mappings - in NDJSON format (one JSON object per line). + output_dir (str, optional): Directory to save the output files. + split_by_frame_id_and_compile_id (bool, optional): Whether to split + output files by frame_id and compile_id. Defaults to True. """ - outputs = defaultdict(list) + kernels_by_hash = defaultdict( + lambda: {"compilation": None, "launches": [], "output_file": None} + ) - # Set default output directory if not provided output_dir = output_dir or os.path.dirname(file_path) - - # Check if input file is compressed based on file extension is_compressed_input = file_path.endswith(".bin.ndjson") - - # Open file in appropriate mode - use gzip.open for compressed files - if is_compressed_input: - # Use gzip.open which automatically handles member concatenation - file_handle = gzip.open(file_path, "rt", encoding="utf-8") - else: - file_handle = open(file_path, "r") + file_handle = ( + gzip.open(file_path, "rt", encoding="utf-8") + if is_compressed_input + else open(file_path, "r") + ) with file_handle as f: file_name = os.path.basename(file_path) - # Handle .bin.ndjson extension properly - if is_compressed_input: - file_name_without_extension = file_name[:-11] # Remove .bin.ndjson - else: - file_name_without_extension = os.path.splitext(file_name)[0] + file_name_without_extension = ( + file_name[:-11] if is_compressed_input else os.path.splitext(file_name)[0] + ) - # Process lines uniformly for both compressed and uncompressed files for i, line in enumerate(f): logger.debug(f"Processing line {i + 1} in {file_path}") - json_str = line.strip() if not json_str: continue - parsed_line = parse_single_trace_content(json_str) - if not parsed_line: - logger.warning(f"Failed to parse line {i + 1} in {file_path}") + # We don't need to generate full mappings for every line here, + # just enough to get the event type and necessary IDs. + try: + parsed_json = json.loads(json_str) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON on line {i + 1} in {file_path}") continue - parsed_json = json.loads(parsed_line) - payload = parsed_json.get("payload", None) - if split_by_frame_id_and_compile_id: - if not payload: - logger.warning("No payload found in the parsed JSON.") + event_type = parsed_json.get("event_type", None) + payload = parsed_json.get("payload", {}) + + if event_type == "compilation": + kernel_hash = payload.get("metadata", {}).get("hash") + if not kernel_hash: continue - pt_info = payload.get("pt_info", {}) - frame_id = pt_info.get("frame_id", None) - frame_compile_id = pt_info.get("frame_compile_id", None) - compiled_autograd_id = pt_info.get("compiled_autograd_id", "-") - attempt_id = pt_info.get("attempt_id", 0) - output_file_name = "" - if frame_id is not None or frame_compile_id is not None: - output_file_name = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{compiled_autograd_id}.ndjson" + + if split_by_frame_id_and_compile_id: + pt_info = payload.get("pt_info", {}) + frame_id = pt_info.get("frame_id") + frame_compile_id = pt_info.get("frame_compile_id") + attempt_id = pt_info.get("attempt_id", 0) + cai = pt_info.get("compiled_autograd_id", "-") + if frame_id is not None or frame_compile_id is not None: + fname = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson" + else: + fname = f"{file_name_without_extension}_mapped.ndjson" else: - logger.debug( - "No frame_id or frame_compile_id found in the payload." + fname = f"{file_name_without_extension}_mapped.ndjson" + + output_file = os.path.join(output_dir, fname) + # The full processing is deferred until the final write. + kernels_by_hash[kernel_hash]["compilation"] = json_str + kernels_by_hash[kernel_hash]["output_file"] = output_file + + elif event_type == "launch": + kernel_hash = parsed_json.get("compilation_metadata", {}).get("hash") + if kernel_hash: + kernels_by_hash[kernel_hash]["launches"].append( + (parsed_json, i + 1) ) - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - else: - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - output_file = os.path.join(output_dir, output_file_name) - outputs[output_file].append(parsed_line) - logger.debug(f"output file: {output_file}") + + # Organize lines for final output, keyed by output file path + all_output_lines = defaultdict(list) + for _kernel_hash, data in kernels_by_hash.items(): + compilation_json_str = data["compilation"] + launches_with_indices = data["launches"] + output_file = data["output_file"] + + if not output_file: + logger.warning(f"No output file for kernel hash {_kernel_hash}, skipping.") + continue + + # Process the compilation event now to include source mappings + if compilation_json_str: + processed_compilation_line = parse_single_trace_content( + compilation_json_str + ) + all_output_lines[output_file].append(processed_compilation_line) + compilation_event = json.loads(processed_compilation_line) + else: + compilation_event = None + + for launch_event, _ in launches_with_indices: + all_output_lines[output_file].append( + json.dumps(launch_event, separators=(",", ":")) + "\n" + ) + + if compilation_event and launches_with_indices: + sames, diffs, launch_index_map = _generate_launch_diff( + launches_with_indices + ) + launch_diff_event = { + "event_type": "launch_diff", + "hash": _kernel_hash, + "name": compilation_event.get("payload", {}) + .get("metadata", {}) + .get("name"), + "total_launches": len(launches_with_indices), + "launch_index_map": launch_index_map, + "diffs": diffs, + "sames": sames, + } + all_output_lines[output_file].append( + json.dumps(launch_diff_event, separators=(",", ":")) + "\n" + ) if not os.path.exists(output_dir): os.makedirs(output_dir) - for output_file, parsed_lines in outputs.items(): + + for output_file, final_lines in all_output_lines.items(): with open(output_file, "w") as out: - out.writelines(parsed_lines) + out.writelines(final_lines) def parse_args(): @@ -638,6 +682,178 @@ def parse_args(): return parser.parse_args() +# Fields that are expected to vary but are not useful to list out in the diff. +SUMMARY_FIELDS = ["pid", "timestamp", "stream", "function", "data_ptr"] + + +def _flatten_dict( + d: Dict[str, Any], parent_key: str = "", sep: str = "." +) -> Dict[str, Any]: + """ + Flattens a nested dictionary. + """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def _unflatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: + """ + Unflattens a dictionary with delimited keys. + """ + result = {} + for key, value in d.items(): + parts = key.split(sep) + d_ref = result + for part in parts[:-1]: + if part not in d_ref: + d_ref[part] = {} + d_ref = d_ref[part] + d_ref[parts[-1]] = value + return result + + +def _to_ranges(indices: List[int]) -> List[Dict[str, int]]: + """ + Converts a sorted list of indices into a list of continuous ranges. + e.g., [0, 1, 2, 5, 6, 8] -> [{'start': 0, 'end': 2}, {'start': 5, 'end': 6}, {'start': 8, 'end': 8}] + """ + if not indices: + return [] + + indices = sorted(indices) + ranges = [] + start = indices[0] + end = indices[0] + + for i in range(1, len(indices)): + if indices[i] == end + 1: + end = indices[i] + else: + ranges.append({"start": start, "end": end}) + start = end = indices[i] + + ranges.append({"start": start, "end": end}) + return ranges + + +def _generate_launch_diff( + launches: List[Tuple[Dict[str, Any], int]], +) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, int]]]: + """ + Compares a list of launch events and returns sames, diffs, and an index map. + """ + if not launches: + return {}, {}, [] + + launch_events = [launch[0] for launch in launches] + launch_index_map = [launch[1] for launch in launches] + + if len(launch_events) == 1: + return ( + _unflatten_dict(_flatten_dict(launch_events[0])), + {}, + _to_ranges(launch_index_map), + ) + + # Group values by key + data_by_key = defaultdict(lambda: defaultdict(list)) + for i, launch in enumerate(launch_events): + launch_flat = _flatten_dict(launch) + for key, value in launch_flat.items(): + # JSON doesn't support all Python types as values directly, str is safer + value_str = json.dumps(value, sort_keys=True) + data_by_key[key][value_str].append(i) + + sames_flat = {} + diffs_flat = {} + + for key, value_groups in data_by_key.items(): + if len(value_groups) == 1: + # This key has the same value across all launches + value_str = list(value_groups.keys())[0] + sames_flat[key] = json.loads(value_str) + else: + # This key has different values + is_summary = any(summary_key in key for summary_key in SUMMARY_FIELDS) + if is_summary: + diffs_flat[key] = { + "diff_type": "summary", + "summary_text": f"Varies across {len(value_groups)} unique values", + } + else: + values_dist = [] + for value_str, indices in value_groups.items(): + values_dist.append( + { + "value": json.loads(value_str), + "count": len(indices), + "launches": _to_ranges(indices), + } + ) + # Sort by first occurrence + values_dist.sort(key=lambda x: x["launches"][0]["start"]) + diffs_flat[key] = { + "diff_type": "distribution", + "values": values_dist, + } + + # Unflatten the results + sames_unflattened = _unflatten_dict(sames_flat) + diffs_unflattened = _unflatten_dict(diffs_flat) + + # Special handling for extracted_args to create argument_diff structures + if "extracted_args" in sames_unflattened or "extracted_args" in diffs_unflattened: + sames_args = sames_unflattened.pop("extracted_args", {}) + diffs_args_flat = diffs_unflattened.pop("extracted_args", {}) + + all_arg_names = set(sames_args.keys()) | set(diffs_args_flat.keys()) + + final_arg_diffs = {} + + for arg_name in all_arg_names: + if arg_name in diffs_args_flat: + # This argument has at least one differing sub-field. + arg_sames = {} + arg_diffs_internal = {} + + # Collect all sub-fields for this argument from the original data + all_sub_fields = set() + for launch in launch_events: + arg_data = launch.get("extracted_args", {}).get(arg_name, {}) + all_sub_fields.update(arg_data.keys()) + + for sub_field in all_sub_fields: + flat_key = f"extracted_args.{arg_name}.{sub_field}" + if flat_key in diffs_flat: + arg_diffs_internal[sub_field] = diffs_flat[flat_key] + elif flat_key in sames_flat: + arg_sames[sub_field] = sames_flat[flat_key] + + if arg_sames or arg_diffs_internal: + final_arg_diffs[arg_name] = { + "diff_type": "argument_diff", + "sames": arg_sames, + "diffs": arg_diffs_internal, + } + elif arg_name in sames_args: + # This argument is entirely the same across all launches. + # We move it back to the main sames dict for consistency. + if "extracted_args" not in sames_unflattened: + sames_unflattened["extracted_args"] = {} + sames_unflattened["extracted_args"][arg_name] = sames_args[arg_name] + + if final_arg_diffs: + diffs_unflattened["extracted_args"] = final_arg_diffs + + return sames_unflattened, diffs_unflattened, _to_ranges(launch_index_map) + + if __name__ == "__main__": args = parse_args() if args.input: diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py index 76c6b6d..81711f9 100644 --- a/tritonparse/structured_logging.py +++ b/tritonparse/structured_logging.py @@ -156,6 +156,9 @@ def convert(obj): # 2. simple containers ---------------------------------------------------- if isinstance(obj, (list, tuple)): + # Handle namedtuple specially to preserve field names + if hasattr(obj, "_asdict"): + return convert(obj._asdict()) return [convert(x) for x in obj] if isinstance(obj, (set, frozenset)): @@ -810,7 +813,7 @@ def extract_arg_info(arg_dict): def add_launch_metadata(grid, metadata, arg_dict): # Extract detailed argument information extracted_args = extract_arg_info(arg_dict) - return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)} + return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)} class JITHookImpl(JITHook): @@ -928,7 +931,7 @@ def __call__(self, metadata): ) if launch_metadata_tritonparse is not None: trace_data["grid"] = launch_metadata_tritonparse[0] - trace_data["metadata"] = launch_metadata_tritonparse[1] + trace_data["compilation_metadata"] = launch_metadata_tritonparse[1] trace_data["extracted_args"] = launch_metadata_tritonparse[ 2 ] # Now contains detailed arg info