diff --git a/tests/test_complex_kernels.py b/tests/test_complex_kernels.py new file mode 100644 index 0000000..389d8a4 --- /dev/null +++ b/tests/test_complex_kernels.py @@ -0,0 +1,256 @@ +""" +A more complex test case involving two distinct Triton kernels, one of which uses autotuning. +This test is designed to validate the launch_diff functionality with multiple, varied launches. + +Test Plan: +``` +TORCHINDUCTOR_FX_GRAPH_CACHE=0 TRITONPARSE_DEBUG=1 python tests/test_complex_kernels.py +``` +""" + +import os + +import torch +import triton +import triton.language as tl + +import tritonparse.structured_logging +import tritonparse.utils + +# Initialize logging +log_path = "./logs_complex" +tritonparse.structured_logging.init(log_path, enable_trace_launch=True) + +os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0" + + +# Kernel 1: Autotuned Matmul (simplified configs for small scale) +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + }, + num_stages=1, + num_warps=1, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + }, + num_stages=1, + num_warps=1, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + }, + num_stages=1, + num_warps=1, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a, + b, + c, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size) + pid_n = (pid % num_pid_in_group) // group_size + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_block = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + b_block = tl.load( + b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator += tl.dot(a_block, b_block) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c_block = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c_block, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + 
b.stride(1), + c.stride(0), + c.stride(1), + ACTIVATION=None, + ) + return c + + +# Kernel 2: Fused element-wise operation +@triton.jit +def fused_op_kernel( + a_ptr, + b_ptr, + c_ptr, + output_ptr, + n_elements, + scale_factor: float, + ACTIVATION: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + c = tl.load(c_ptr + offsets, mask=mask) + + result = a * b * scale_factor + c + if ACTIVATION == "relu": + result = tl.where(result > 0, result, 0.0) + + tl.store(output_ptr + offsets, result, mask=mask) + + +def fused_op(a, b, c, scale_factor: float, activation: str): + n_elements = a.numel() + output = torch.empty_like(a) + BLOCK_SIZE = 8 # Reduced from 1024 for small scale testing + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + fused_op_kernel[grid]( + a, + b, + c, + output, + n_elements, + scale_factor, + ACTIVATION=activation, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +def test_complex_kernels(): + """Main test function to run both kernels with varied parameters.""" + torch.manual_seed(0) + + # --- Matmul Launches (3 times with different configs) --- + print("--- Testing Matmul Kernel (3 launches) ---") + # Launch 1 + a1 = torch.randn((16, 16), device="cuda", dtype=torch.float16) + b1 = torch.randn((16, 16), device="cuda", dtype=torch.float16) + c1 = matmul(a1, b1) + c1.sum() # Synchronize + print("Matmul Launch 1 (16x16 @ 16x16) done.") + + # Launch 2 + a2 = torch.randn((32, 16), device="cuda", dtype=torch.float16) + b2 = torch.randn((16, 32), device="cuda", dtype=torch.float16) + c2 = matmul(a2, b2) + c2.sum() # Synchronize + print("Matmul Launch 2 (32x16 @ 16x32) done.") + + # Launch 3 + a3 = torch.randn((16, 32), device="cuda", dtype=torch.float16) + b3 = torch.randn((32, 16), device="cuda", dtype=torch.float16) + c3 = matmul(a3, b3) + c3.sum() # Synchronize + print("Matmul Launch 3 (16x32 @ 32x16) done.") + + # --- Fused Op Launches (4 times with different parameters) --- + print("\n--- Testing Fused Op Kernel (4 launches) ---") + x = torch.randn((8,), device="cuda", dtype=torch.float32) + y = torch.randn((8,), device="cuda", dtype=torch.float32) + z = torch.randn((8,), device="cuda", dtype=torch.float32) + + # Launch 1 + print("Fused Op Launch 1: scale=1.0, activation=None") + out1 = fused_op(x, y, z, scale_factor=1.0, activation="none") + out1.sum() # Synchronize + + # Launch 2 + print("Fused Op Launch 2: scale=2.5, activation=None") + out2 = fused_op(x, y, z, scale_factor=2.5, activation="none") + out2.sum() # Synchronize + + # Launch 3 + print("Fused Op Launch 3: scale=1.0, activation='relu'") + out3 = fused_op(x, y, z, scale_factor=1.0, activation="relu") + out3.sum() # Synchronize + + # Launch 4 (different size) + print("Fused Op Launch 4: scale=1.0, activation='relu', different size") + x_large = torch.randn((6,), device="cuda", dtype=torch.float32) + y_large = torch.randn((6,), device="cuda", dtype=torch.float32) + z_large = torch.randn((6,), device="cuda", dtype=torch.float32) + out4 = fused_op(x_large, y_large, z_large, scale_factor=1.0, activation="relu") + out4.sum() # Synchronize + print("All kernels executed.") + + +if __name__ == "__main__": + test_complex_kernels() + # Use unified_parse to process the generated logs + tritonparse.utils.unified_parse( + source=log_path, out="./parsed_output_complex", overwrite=True + ) diff --git 
a/tritonparse/extract_source_mappings.py b/tritonparse/extract_source_mappings.py index 194f4c9..a3fd972 100644 --- a/tritonparse/extract_source_mappings.py +++ b/tritonparse/extract_source_mappings.py @@ -449,80 +449,80 @@ def parse_single_trace_content(trace_content: str) -> str: """ entry = json.loads(trace_content) - if entry.get("event_type") != "compilation" and "payload" in entry: - logger.warning("Not a compilation event. Skipping.") - return "" - payload = entry.setdefault("payload", {}) - file_content = payload.get("file_content", {}) - file_path = payload.get("file_path", {}) - - # Find the IR file keys - ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) - ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) - ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) - amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) - # Skip if no IR files found - if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): - logger.warning("No IR files found in the payload.") - return trace_content - - # generate ttir->source, ttgir->source, ptx->source - ttir_map = process_ir(ttir_key, file_content, file_path) - ttgir_map = process_ir(ttgir_key, file_content, file_path) - ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) - amdgcn_map = process_ir(amdgcn_key, file_content, file_path, [ttir_map, ttgir_map]) - - # Create bidirectional mappings between all IR types - ir_maps = { - "ttir": ttir_map, - "ttgir": ttgir_map, - "ptx": ptx_map, - "amdgcn": amdgcn_map, - } - - # Create mappings between all pairs of IR types - ir_types = list(ir_maps.keys()) - for i, src_type in enumerate(ir_types): - for tgt_type in ir_types[i + 1 :]: - if ir_maps[src_type] and ir_maps[tgt_type]: - create_bidirectional_mapping( - ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type - ) - logger.debug( - f"Created bidirectional mapping between {src_type} and {tgt_type}" - ) - - py_map = {} - - if "python_source" in payload: - logger.debug( - f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + if entry.get("event_type") == "compilation": + payload = entry.setdefault("payload", {}) + file_content = payload.get("file_content", {}) + file_path = payload.get("file_path", {}) + + # Find the IR file keys + ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) + ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) + ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) + amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) + # Skip if no IR files found + if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): + logger.warning("No IR files found in the payload.") + return trace_content + + # generate ttir->source, ttgir->source, ptx->source + ttir_map = process_ir(ttir_key, file_content, file_path) + ttgir_map = process_ir(ttgir_key, file_content, file_path) + ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) + amdgcn_map = process_ir( + amdgcn_key, file_content, file_path, [ttir_map, ttgir_map] ) - # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. 
- # Create a list of valid IR mappings, filtering out None keys - ir_mappings = [] - ir_keys_and_maps = [ - (ttir_key, ttir_map), - (ttgir_key, ttgir_map), - (ptx_key, ptx_map), - (amdgcn_key, amdgcn_map), - ] - - for key, mapping in ir_keys_and_maps: - if key: - ir_mappings.append((get_file_extension(key), mapping)) - - py_map = create_python_mapping(ir_mappings) - - # Store the mappings in the payload - payload["source_mappings"] = { - "ttir": ttir_map, - "ttgir": ttgir_map, - **({"ptx": ptx_map} if ptx_map else {}), - **({"amdgcn": amdgcn_map} if amdgcn_map else {}), - "python": py_map, - } + # Create bidirectional mappings between all IR types + ir_maps = { + "ttir": ttir_map, + "ttgir": ttgir_map, + "ptx": ptx_map, + "amdgcn": amdgcn_map, + } + + # Create mappings between all pairs of IR types + ir_types = list(ir_maps.keys()) + for i, src_type in enumerate(ir_types): + for tgt_type in ir_types[i + 1 :]: + if ir_maps[src_type] and ir_maps[tgt_type]: + create_bidirectional_mapping( + ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type + ) + logger.debug( + f"Created bidirectional mapping between {src_type} and {tgt_type}" + ) + + py_map = {} + + if "python_source" in payload: + logger.debug( + f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + ) + + # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. + # Create a list of valid IR mappings, filtering out None keys + ir_mappings = [] + ir_keys_and_maps = [ + (ttir_key, ttir_map), + (ttgir_key, ttgir_map), + (ptx_key, ptx_map), + (amdgcn_key, amdgcn_map), + ] + + for key, mapping in ir_keys_and_maps: + if key: + ir_mappings.append((get_file_extension(key), mapping)) + + py_map = create_python_mapping(ir_mappings) + + # Store the mappings in the payload + payload["source_mappings"] = { + "ttir": ttir_map, + "ttgir": ttgir_map, + **({"ptx": ptx_map} if ptx_map else {}), + **({"amdgcn": amdgcn_map} if amdgcn_map else {}), + "python": py_map, + } # NDJSON format requires a newline at the end of each line return json.dumps(entry, separators=(",", ":")) + "\n" @@ -533,90 +533,134 @@ def parse_single_file( split_by_frame_id_and_compile_id: bool = True, ): """ - Process a single file and extract source code mappings. + Process a single file, correctly group events by kernel, and extract mappings. - This function takes a file path as input, reads the file content, and extracts - source code mappings from Triton trace JSON files. It processes each line of the file, - parses the trace content to extract IR mappings, and writes the updated content - to output files. + This function reads a trace file, groups compilation and launch events by + their kernel hash, generates a launch_diff event for each kernel, and writes + the processed data to output files. Args: file_path (str): The path to the file to be processed. - output_dir (str, optional): Directory to save the output files with mappings. - split_by_frame_id_and_compile_id (bool, optional): Whether to split output files - by frame_id and compile_id. Defaults to True. - - Returns: - None. The function writes the processed data to files in the output_dir. - Each output file will contain the original trace data enriched with source mappings - in NDJSON format (one JSON object per line). + output_dir (str, optional): Directory to save the output files. + split_by_frame_id_and_compile_id (bool, optional): Whether to split + output files by frame_id and compile_id. 
Defaults to True. """ - outputs = defaultdict(list) + kernels_by_hash = defaultdict( + lambda: {"compilation": None, "launches": [], "output_file": None} + ) - # Set default output directory if not provided output_dir = output_dir or os.path.dirname(file_path) - - # Check if input file is compressed based on file extension is_compressed_input = file_path.endswith(".bin.ndjson") - - # Open file in appropriate mode - use gzip.open for compressed files - if is_compressed_input: - # Use gzip.open which automatically handles member concatenation - file_handle = gzip.open(file_path, "rt", encoding="utf-8") - else: - file_handle = open(file_path, "r") + file_handle = ( + gzip.open(file_path, "rt", encoding="utf-8") + if is_compressed_input + else open(file_path, "r") + ) with file_handle as f: file_name = os.path.basename(file_path) - # Handle .bin.ndjson extension properly - if is_compressed_input: - file_name_without_extension = file_name[:-11] # Remove .bin.ndjson - else: - file_name_without_extension = os.path.splitext(file_name)[0] + file_name_without_extension = ( + file_name[:-11] if is_compressed_input else os.path.splitext(file_name)[0] + ) - # Process lines uniformly for both compressed and uncompressed files for i, line in enumerate(f): logger.debug(f"Processing line {i + 1} in {file_path}") - json_str = line.strip() if not json_str: continue - parsed_line = parse_single_trace_content(json_str) - if not parsed_line: - logger.warning(f"Failed to parse line {i + 1} in {file_path}") + # We don't need to generate full mappings for every line here, + # just enough to get the event type and necessary IDs. + try: + parsed_json = json.loads(json_str) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON on line {i + 1} in {file_path}") continue - parsed_json = json.loads(parsed_line) - payload = parsed_json.get("payload", None) - if split_by_frame_id_and_compile_id: - if not payload: - logger.warning("No payload found in the parsed JSON.") + event_type = parsed_json.get("event_type", None) + payload = parsed_json.get("payload", {}) + + if event_type == "compilation": + kernel_hash = payload.get("metadata", {}).get("hash") + if not kernel_hash: continue - pt_info = payload.get("pt_info", {}) - frame_id = pt_info.get("frame_id", None) - frame_compile_id = pt_info.get("frame_compile_id", None) - compiled_autograd_id = pt_info.get("compiled_autograd_id", "-") - attempt_id = pt_info.get("attempt_id", 0) - output_file_name = "" - if frame_id is not None or frame_compile_id is not None: - output_file_name = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{compiled_autograd_id}.ndjson" + + if split_by_frame_id_and_compile_id: + pt_info = payload.get("pt_info", {}) + frame_id = pt_info.get("frame_id") + frame_compile_id = pt_info.get("frame_compile_id") + attempt_id = pt_info.get("attempt_id", 0) + cai = pt_info.get("compiled_autograd_id", "-") + if frame_id is not None or frame_compile_id is not None: + fname = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson" + else: + fname = f"{file_name_without_extension}_mapped.ndjson" else: - logger.debug( - "No frame_id or frame_compile_id found in the payload." + fname = f"{file_name_without_extension}_mapped.ndjson" + + output_file = os.path.join(output_dir, fname) + # The full processing is deferred until the final write. 
+ kernels_by_hash[kernel_hash]["compilation"] = json_str + kernels_by_hash[kernel_hash]["output_file"] = output_file + + elif event_type == "launch": + kernel_hash = parsed_json.get("compilation_metadata", {}).get("hash") + if kernel_hash: + kernels_by_hash[kernel_hash]["launches"].append( + (parsed_json, i + 1) ) - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - else: - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - output_file = os.path.join(output_dir, output_file_name) - outputs[output_file].append(parsed_line) - logger.debug(f"output file: {output_file}") + + # Organize lines for final output, keyed by output file path + all_output_lines = defaultdict(list) + for _kernel_hash, data in kernels_by_hash.items(): + compilation_json_str = data["compilation"] + launches_with_indices = data["launches"] + output_file = data["output_file"] + + if not output_file: + logger.warning(f"No output file for kernel hash {_kernel_hash}, skipping.") + continue + + # Process the compilation event now to include source mappings + if compilation_json_str: + processed_compilation_line = parse_single_trace_content( + compilation_json_str + ) + all_output_lines[output_file].append(processed_compilation_line) + compilation_event = json.loads(processed_compilation_line) + else: + compilation_event = None + + for launch_event, _ in launches_with_indices: + all_output_lines[output_file].append( + json.dumps(launch_event, separators=(",", ":")) + "\n" + ) + + if compilation_event and launches_with_indices: + sames, diffs, launch_index_map = _generate_launch_diff( + launches_with_indices + ) + launch_diff_event = { + "event_type": "launch_diff", + "hash": _kernel_hash, + "name": compilation_event.get("payload", {}) + .get("metadata", {}) + .get("name"), + "total_launches": len(launches_with_indices), + "launch_index_map": launch_index_map, + "diffs": diffs, + "sames": sames, + } + all_output_lines[output_file].append( + json.dumps(launch_diff_event, separators=(",", ":")) + "\n" + ) if not os.path.exists(output_dir): os.makedirs(output_dir) - for output_file, parsed_lines in outputs.items(): + + for output_file, final_lines in all_output_lines.items(): with open(output_file, "w") as out: - out.writelines(parsed_lines) + out.writelines(final_lines) def parse_args(): @@ -638,6 +682,178 @@ def parse_args(): return parser.parse_args() +# Fields that are expected to vary but are not useful to list out in the diff. +SUMMARY_FIELDS = ["pid", "timestamp", "stream", "function", "data_ptr"] + + +def _flatten_dict( + d: Dict[str, Any], parent_key: str = "", sep: str = "." +) -> Dict[str, Any]: + """ + Flattens a nested dictionary. + """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def _unflatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: + """ + Unflattens a dictionary with delimited keys. + """ + result = {} + for key, value in d.items(): + parts = key.split(sep) + d_ref = result + for part in parts[:-1]: + if part not in d_ref: + d_ref[part] = {} + d_ref = d_ref[part] + d_ref[parts[-1]] = value + return result + + +def _to_ranges(indices: List[int]) -> List[Dict[str, int]]: + """ + Converts a sorted list of indices into a list of continuous ranges. 
+ e.g., [0, 1, 2, 5, 6, 8] -> [{'start': 0, 'end': 2}, {'start': 5, 'end': 6}, {'start': 8, 'end': 8}] + """ + if not indices: + return [] + + indices = sorted(indices) + ranges = [] + start = indices[0] + end = indices[0] + + for i in range(1, len(indices)): + if indices[i] == end + 1: + end = indices[i] + else: + ranges.append({"start": start, "end": end}) + start = end = indices[i] + + ranges.append({"start": start, "end": end}) + return ranges + + +def _generate_launch_diff( + launches: List[Tuple[Dict[str, Any], int]], +) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, int]]]: + """ + Compares a list of launch events and returns sames, diffs, and an index map. + """ + if not launches: + return {}, {}, [] + + launch_events = [launch[0] for launch in launches] + launch_index_map = [launch[1] for launch in launches] + + if len(launch_events) == 1: + return ( + _unflatten_dict(_flatten_dict(launch_events[0])), + {}, + _to_ranges(launch_index_map), + ) + + # Group values by key + data_by_key = defaultdict(lambda: defaultdict(list)) + for i, launch in enumerate(launch_events): + launch_flat = _flatten_dict(launch) + for key, value in launch_flat.items(): + # JSON doesn't support all Python types as values directly, str is safer + value_str = json.dumps(value, sort_keys=True) + data_by_key[key][value_str].append(i) + + sames_flat = {} + diffs_flat = {} + + for key, value_groups in data_by_key.items(): + if len(value_groups) == 1: + # This key has the same value across all launches + value_str = list(value_groups.keys())[0] + sames_flat[key] = json.loads(value_str) + else: + # This key has different values + is_summary = any(summary_key in key for summary_key in SUMMARY_FIELDS) + if is_summary: + diffs_flat[key] = { + "diff_type": "summary", + "summary_text": f"Varies across {len(value_groups)} unique values", + } + else: + values_dist = [] + for value_str, indices in value_groups.items(): + values_dist.append( + { + "value": json.loads(value_str), + "count": len(indices), + "launches": _to_ranges(indices), + } + ) + # Sort by first occurrence + values_dist.sort(key=lambda x: x["launches"][0]["start"]) + diffs_flat[key] = { + "diff_type": "distribution", + "values": values_dist, + } + + # Unflatten the results + sames_unflattened = _unflatten_dict(sames_flat) + diffs_unflattened = _unflatten_dict(diffs_flat) + + # Special handling for extracted_args to create argument_diff structures + if "extracted_args" in sames_unflattened or "extracted_args" in diffs_unflattened: + sames_args = sames_unflattened.pop("extracted_args", {}) + diffs_args_flat = diffs_unflattened.pop("extracted_args", {}) + + all_arg_names = set(sames_args.keys()) | set(diffs_args_flat.keys()) + + final_arg_diffs = {} + + for arg_name in all_arg_names: + if arg_name in diffs_args_flat: + # This argument has at least one differing sub-field. 
+ arg_sames = {} + arg_diffs_internal = {} + + # Collect all sub-fields for this argument from the original data + all_sub_fields = set() + for launch in launch_events: + arg_data = launch.get("extracted_args", {}).get(arg_name, {}) + all_sub_fields.update(arg_data.keys()) + + for sub_field in all_sub_fields: + flat_key = f"extracted_args.{arg_name}.{sub_field}" + if flat_key in diffs_flat: + arg_diffs_internal[sub_field] = diffs_flat[flat_key] + elif flat_key in sames_flat: + arg_sames[sub_field] = sames_flat[flat_key] + + if arg_sames or arg_diffs_internal: + final_arg_diffs[arg_name] = { + "diff_type": "argument_diff", + "sames": arg_sames, + "diffs": arg_diffs_internal, + } + elif arg_name in sames_args: + # This argument is entirely the same across all launches. + # We move it back to the main sames dict for consistency. + if "extracted_args" not in sames_unflattened: + sames_unflattened["extracted_args"] = {} + sames_unflattened["extracted_args"][arg_name] = sames_args[arg_name] + + if final_arg_diffs: + diffs_unflattened["extracted_args"] = final_arg_diffs + + return sames_unflattened, diffs_unflattened, _to_ranges(launch_index_map) + + if __name__ == "__main__": args = parse_args() if args.input: diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py index 76c6b6d..81711f9 100644 --- a/tritonparse/structured_logging.py +++ b/tritonparse/structured_logging.py @@ -156,6 +156,9 @@ def convert(obj): # 2. simple containers ---------------------------------------------------- if isinstance(obj, (list, tuple)): + # Handle namedtuple specially to preserve field names + if hasattr(obj, "_asdict"): + return convert(obj._asdict()) return [convert(x) for x in obj] if isinstance(obj, (set, frozenset)): @@ -810,7 +813,7 @@ def extract_arg_info(arg_dict): def add_launch_metadata(grid, metadata, arg_dict): # Extract detailed argument information extracted_args = extract_arg_info(arg_dict) - return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)} + return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)} class JITHookImpl(JITHook): @@ -928,7 +931,7 @@ def __call__(self, metadata): ) if launch_metadata_tritonparse is not None: trace_data["grid"] = launch_metadata_tritonparse[0] - trace_data["metadata"] = launch_metadata_tritonparse[1] + trace_data["compilation_metadata"] = launch_metadata_tritonparse[1] trace_data["extracted_args"] = launch_metadata_tritonparse[ 2 ] # Now contains detailed arg info diff --git a/website/src/components/ArgumentViewer.tsx b/website/src/components/ArgumentViewer.tsx new file mode 100644 index 0000000..9925bb4 --- /dev/null +++ b/website/src/components/ArgumentViewer.tsx @@ -0,0 +1,134 @@ +import React, { useState } from 'react'; + +// Renders the value distribution (e.g., "16 (2 times, in launches: 1-2)") +const DistributionCell: React.FC<{ data: any }> = ({ data }) => { + if (!data) return null; + if (data.diff_type === 'summary') { + return {data.summary_text}; + } + if (data.diff_type === 'distribution' && data.values) { + return ( +
{JSON.stringify(argData, null, 2)}+ } +
No differing fields detected.
+ ); + } + + // Separate different kinds of diffs + const extractedArgs = diffs.extracted_args; + const stackDiff = diffs.stack; + const otherDiffs = Object.fromEntries( + Object.entries(diffs).filter( + ([key]) => key !== "extracted_args" && key !== "stack" + ) + ); + + const renderSimpleDiff = (_key: string, data: any) => { + if (data.diff_type === "summary") { + return{data.summary_text}
; + } + if (data.diff_type === "distribution") { + return ( +{JSON.stringify(data, null, 2)}; + }; + + return ( +
+ Variant seen {item.count} times (in launches: {launchRanges}) +
+Invalid stack format
} ++ Total Launches:{" "} + {kernel.launchDiff.total_launches} +
- {/* Other Metadata */}
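
For orientation, here is a sketch of the `launch_diff` NDJSON event that `parse_single_file` now appends after each kernel's compilation and launch events. Only the key layout follows the code above; the hash, kernel name, and argument values are hypothetical placeholders.

```python
# Hypothetical example of a launch_diff event (values are placeholders; the
# field layout mirrors _generate_launch_diff and parse_single_file above).
example_launch_diff_event = {
    "event_type": "launch_diff",
    "hash": "0f3a...",                      # kernel hash shared by compilation and launches
    "name": "fused_op_kernel",              # taken from the compilation payload metadata
    "total_launches": 4,
    "launch_index_map": [{"start": 3, "end": 6}],  # 1-based line numbers of launch events in the input trace
    "sames": {                               # fields identical across all launches
        "grid": [1],
        "extracted_args": {"n_elements": {"type": "int", "value": 8}},
    },
    "diffs": {
        # SUMMARY_FIELDS (pid, timestamp, stream, ...) are collapsed to a summary
        "timestamp": {
            "diff_type": "summary",
            "summary_text": "Varies across 4 unique values",
        },
        # arguments with at least one differing sub-field use the argument_diff structure
        "extracted_args": {
            "scale_factor": {
                "diff_type": "argument_diff",
                "sames": {"type": "float"},
                "diffs": {
                    "value": {
                        "diff_type": "distribution",
                        "values": [
                            # "launches" here are 0-based positions in the launch list
                            {"value": 1.0, "count": 3,
                             "launches": [{"start": 0, "end": 0}, {"start": 2, "end": 3}]},
                            {"value": 2.5, "count": 1,
                             "launches": [{"start": 1, "end": 1}]},
                        ],
                    }
                },
            }
        },
    },
}
```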
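
The grouping logic leans on three small helpers; a quick doctest-style sketch of their behaviour, read off the implementations above, assuming they are imported from `tritonparse.extract_source_mappings`:

```python
from tritonparse.extract_source_mappings import (
    _flatten_dict,
    _to_ranges,
    _unflatten_dict,
)

# Nested launch data is flattened to dotted keys before per-key comparison...
assert _flatten_dict({"launch_metadata": {"grid": [1, 1, 1]}, "stream": 7}) == {
    "launch_metadata.grid": [1, 1, 1],
    "stream": 7,
}
# ...and unflattened again when the sames/diffs are written back out.
assert _unflatten_dict({"a.b": 1, "a.c": 2, "d": 3}) == {"a": {"b": 1, "c": 2}, "d": 3}

# Launch indices are compacted into inclusive ranges for the report.
assert _to_ranges([0, 1, 2, 5, 6, 8]) == [
    {"start": 0, "end": 2},
    {"start": 5, "end": 6},
    {"start": 8, "end": 8},
]
```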
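
The `convert()` change in `structured_logging.py` matters because the metadata object passed to `add_launch_metadata` is namedtuple-like (it exposes `_asdict()`); without the new check it fell through the list/tuple branch and serialized positionally, losing field names. A minimal sketch of the behaviour, where `LaunchMeta` is a hypothetical stand-in rather than Triton's actual metadata class:

```python
from collections import namedtuple

LaunchMeta = namedtuple("LaunchMeta", ["num_warps", "num_stages"])
meta = LaunchMeta(num_warps=4, num_stages=2)

# Old behaviour: serialized positionally as [4, 2].
# New behaviour: the hasattr(obj, "_asdict") check preserves field names.
assert dict(meta._asdict()) == {"num_warps": 4, "num_stages": 2}
```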