From b84d3150687c02787075ec313d8d60a507f4cc64 Mon Sep 17 00:00:00 2001 From: FindHao Date: Sun, 20 Jul 2025 17:14:50 -0400 Subject: [PATCH 1/4] Refactor trace content parsing and enhance source mapping extraction Summary: - Updated `parse_single_trace_content` to only process compilation events, improving efficiency by skipping non-compilation events. - Enhanced the handling of IR file mappings and bidirectional mappings between different IR types. - Improved the organization of output lines for final writing, ensuring accurate grouping of compilation and launch events by kernel hash. - Added new utility functions for flattening and unflattening dictionaries, and for generating launch diffs, enhancing the overall functionality of the module. These changes aim to streamline the parsing process and improve the clarity and maintainability of the codebase. --- tritonparse/extract_source_mappings.py | 474 ++++++++++++++++++------- tritonparse/structured_logging.py | 7 +- 2 files changed, 350 insertions(+), 131 deletions(-) diff --git a/tritonparse/extract_source_mappings.py b/tritonparse/extract_source_mappings.py index 194f4c9..a3fd972 100644 --- a/tritonparse/extract_source_mappings.py +++ b/tritonparse/extract_source_mappings.py @@ -449,80 +449,80 @@ def parse_single_trace_content(trace_content: str) -> str: """ entry = json.loads(trace_content) - if entry.get("event_type") != "compilation" and "payload" in entry: - logger.warning("Not a compilation event. Skipping.") - return "" - payload = entry.setdefault("payload", {}) - file_content = payload.get("file_content", {}) - file_path = payload.get("file_path", {}) - - # Find the IR file keys - ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) - ttgir_key = next((k for k in file_content if k.endswith(".ttgir")), None) - ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) - amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) - # Skip if no IR files found - if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): - logger.warning("No IR files found in the payload.") - return trace_content - - # generate ttir->source, ttgir->source, ptx->source - ttir_map = process_ir(ttir_key, file_content, file_path) - ttgir_map = process_ir(ttgir_key, file_content, file_path) - ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) - amdgcn_map = process_ir(amdgcn_key, file_content, file_path, [ttir_map, ttgir_map]) - - # Create bidirectional mappings between all IR types - ir_maps = { - "ttir": ttir_map, - "ttgir": ttgir_map, - "ptx": ptx_map, - "amdgcn": amdgcn_map, - } - - # Create mappings between all pairs of IR types - ir_types = list(ir_maps.keys()) - for i, src_type in enumerate(ir_types): - for tgt_type in ir_types[i + 1 :]: - if ir_maps[src_type] and ir_maps[tgt_type]: - create_bidirectional_mapping( - ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type - ) - logger.debug( - f"Created bidirectional mapping between {src_type} and {tgt_type}" - ) - - py_map = {} - - if "python_source" in payload: - logger.debug( - f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + if entry.get("event_type") == "compilation": + payload = entry.setdefault("payload", {}) + file_content = payload.get("file_content", {}) + file_path = payload.get("file_path", {}) + + # Find the IR file keys + ttir_key = next((k for k in file_content if k.endswith(".ttir")), None) + ttgir_key = next((k for k in file_content if 
k.endswith(".ttgir")), None) + ptx_key = next((k for k in file_content if k.endswith(".ptx")), None) + amdgcn_key = next((k for k in file_content if k.endswith(".amdgcn")), None) + # Skip if no IR files found + if not (ttir_key or ttgir_key or ptx_key or amdgcn_key): + logger.warning("No IR files found in the payload.") + return trace_content + + # generate ttir->source, ttgir->source, ptx->source + ttir_map = process_ir(ttir_key, file_content, file_path) + ttgir_map = process_ir(ttgir_key, file_content, file_path) + ptx_map = process_ir(ptx_key, file_content, file_path, [ttir_map, ttgir_map]) + amdgcn_map = process_ir( + amdgcn_key, file_content, file_path, [ttir_map, ttgir_map] ) - # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. - # Create a list of valid IR mappings, filtering out None keys - ir_mappings = [] - ir_keys_and_maps = [ - (ttir_key, ttir_map), - (ttgir_key, ttgir_map), - (ptx_key, ptx_map), - (amdgcn_key, amdgcn_map), - ] - - for key, mapping in ir_keys_and_maps: - if key: - ir_mappings.append((get_file_extension(key), mapping)) - - py_map = create_python_mapping(ir_mappings) - - # Store the mappings in the payload - payload["source_mappings"] = { - "ttir": ttir_map, - "ttgir": ttgir_map, - **({"ptx": ptx_map} if ptx_map else {}), - **({"amdgcn": amdgcn_map} if amdgcn_map else {}), - "python": py_map, - } + # Create bidirectional mappings between all IR types + ir_maps = { + "ttir": ttir_map, + "ttgir": ttgir_map, + "ptx": ptx_map, + "amdgcn": amdgcn_map, + } + + # Create mappings between all pairs of IR types + ir_types = list(ir_maps.keys()) + for i, src_type in enumerate(ir_types): + for tgt_type in ir_types[i + 1 :]: + if ir_maps[src_type] and ir_maps[tgt_type]: + create_bidirectional_mapping( + ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type + ) + logger.debug( + f"Created bidirectional mapping between {src_type} and {tgt_type}" + ) + + py_map = {} + + if "python_source" in payload: + logger.debug( + f"Added Python source information (lines {payload['python_source']['start_line']}-{payload['python_source']['end_line']})" + ) + + # 4. Create Python source to IR mappings. We use the original line numbers as key in the python source code. + # Create a list of valid IR mappings, filtering out None keys + ir_mappings = [] + ir_keys_and_maps = [ + (ttir_key, ttir_map), + (ttgir_key, ttgir_map), + (ptx_key, ptx_map), + (amdgcn_key, amdgcn_map), + ] + + for key, mapping in ir_keys_and_maps: + if key: + ir_mappings.append((get_file_extension(key), mapping)) + + py_map = create_python_mapping(ir_mappings) + + # Store the mappings in the payload + payload["source_mappings"] = { + "ttir": ttir_map, + "ttgir": ttgir_map, + **({"ptx": ptx_map} if ptx_map else {}), + **({"amdgcn": amdgcn_map} if amdgcn_map else {}), + "python": py_map, + } # NDJSON format requires a newline at the end of each line return json.dumps(entry, separators=(",", ":")) + "\n" @@ -533,90 +533,134 @@ def parse_single_file( split_by_frame_id_and_compile_id: bool = True, ): """ - Process a single file and extract source code mappings. + Process a single file, correctly group events by kernel, and extract mappings. - This function takes a file path as input, reads the file content, and extracts - source code mappings from Triton trace JSON files. It processes each line of the file, - parses the trace content to extract IR mappings, and writes the updated content - to output files. 
+ This function reads a trace file, groups compilation and launch events by + their kernel hash, generates a launch_diff event for each kernel, and writes + the processed data to output files. Args: file_path (str): The path to the file to be processed. - output_dir (str, optional): Directory to save the output files with mappings. - split_by_frame_id_and_compile_id (bool, optional): Whether to split output files - by frame_id and compile_id. Defaults to True. - - Returns: - None. The function writes the processed data to files in the output_dir. - Each output file will contain the original trace data enriched with source mappings - in NDJSON format (one JSON object per line). + output_dir (str, optional): Directory to save the output files. + split_by_frame_id_and_compile_id (bool, optional): Whether to split + output files by frame_id and compile_id. Defaults to True. """ - outputs = defaultdict(list) + kernels_by_hash = defaultdict( + lambda: {"compilation": None, "launches": [], "output_file": None} + ) - # Set default output directory if not provided output_dir = output_dir or os.path.dirname(file_path) - - # Check if input file is compressed based on file extension is_compressed_input = file_path.endswith(".bin.ndjson") - - # Open file in appropriate mode - use gzip.open for compressed files - if is_compressed_input: - # Use gzip.open which automatically handles member concatenation - file_handle = gzip.open(file_path, "rt", encoding="utf-8") - else: - file_handle = open(file_path, "r") + file_handle = ( + gzip.open(file_path, "rt", encoding="utf-8") + if is_compressed_input + else open(file_path, "r") + ) with file_handle as f: file_name = os.path.basename(file_path) - # Handle .bin.ndjson extension properly - if is_compressed_input: - file_name_without_extension = file_name[:-11] # Remove .bin.ndjson - else: - file_name_without_extension = os.path.splitext(file_name)[0] + file_name_without_extension = ( + file_name[:-11] if is_compressed_input else os.path.splitext(file_name)[0] + ) - # Process lines uniformly for both compressed and uncompressed files for i, line in enumerate(f): logger.debug(f"Processing line {i + 1} in {file_path}") - json_str = line.strip() if not json_str: continue - parsed_line = parse_single_trace_content(json_str) - if not parsed_line: - logger.warning(f"Failed to parse line {i + 1} in {file_path}") + # We don't need to generate full mappings for every line here, + # just enough to get the event type and necessary IDs. 
+ try: + parsed_json = json.loads(json_str) + except json.JSONDecodeError: + logger.warning(f"Failed to parse JSON on line {i + 1} in {file_path}") continue - parsed_json = json.loads(parsed_line) - payload = parsed_json.get("payload", None) - if split_by_frame_id_and_compile_id: - if not payload: - logger.warning("No payload found in the parsed JSON.") + event_type = parsed_json.get("event_type", None) + payload = parsed_json.get("payload", {}) + + if event_type == "compilation": + kernel_hash = payload.get("metadata", {}).get("hash") + if not kernel_hash: continue - pt_info = payload.get("pt_info", {}) - frame_id = pt_info.get("frame_id", None) - frame_compile_id = pt_info.get("frame_compile_id", None) - compiled_autograd_id = pt_info.get("compiled_autograd_id", "-") - attempt_id = pt_info.get("attempt_id", 0) - output_file_name = "" - if frame_id is not None or frame_compile_id is not None: - output_file_name = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{compiled_autograd_id}.ndjson" + + if split_by_frame_id_and_compile_id: + pt_info = payload.get("pt_info", {}) + frame_id = pt_info.get("frame_id") + frame_compile_id = pt_info.get("frame_compile_id") + attempt_id = pt_info.get("attempt_id", 0) + cai = pt_info.get("compiled_autograd_id", "-") + if frame_id is not None or frame_compile_id is not None: + fname = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{cai}.ndjson" + else: + fname = f"{file_name_without_extension}_mapped.ndjson" else: - logger.debug( - "No frame_id or frame_compile_id found in the payload." + fname = f"{file_name_without_extension}_mapped.ndjson" + + output_file = os.path.join(output_dir, fname) + # The full processing is deferred until the final write. + kernels_by_hash[kernel_hash]["compilation"] = json_str + kernels_by_hash[kernel_hash]["output_file"] = output_file + + elif event_type == "launch": + kernel_hash = parsed_json.get("compilation_metadata", {}).get("hash") + if kernel_hash: + kernels_by_hash[kernel_hash]["launches"].append( + (parsed_json, i + 1) ) - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - else: - output_file_name = f"{file_name_without_extension}_mapped.ndjson" - output_file = os.path.join(output_dir, output_file_name) - outputs[output_file].append(parsed_line) - logger.debug(f"output file: {output_file}") + + # Organize lines for final output, keyed by output file path + all_output_lines = defaultdict(list) + for _kernel_hash, data in kernels_by_hash.items(): + compilation_json_str = data["compilation"] + launches_with_indices = data["launches"] + output_file = data["output_file"] + + if not output_file: + logger.warning(f"No output file for kernel hash {_kernel_hash}, skipping.") + continue + + # Process the compilation event now to include source mappings + if compilation_json_str: + processed_compilation_line = parse_single_trace_content( + compilation_json_str + ) + all_output_lines[output_file].append(processed_compilation_line) + compilation_event = json.loads(processed_compilation_line) + else: + compilation_event = None + + for launch_event, _ in launches_with_indices: + all_output_lines[output_file].append( + json.dumps(launch_event, separators=(",", ":")) + "\n" + ) + + if compilation_event and launches_with_indices: + sames, diffs, launch_index_map = _generate_launch_diff( + launches_with_indices + ) + launch_diff_event = { + "event_type": "launch_diff", + "hash": _kernel_hash, + "name": compilation_event.get("payload", {}) + .get("metadata", {}) + .get("name"), + "total_launches": 
len(launches_with_indices), + "launch_index_map": launch_index_map, + "diffs": diffs, + "sames": sames, + } + all_output_lines[output_file].append( + json.dumps(launch_diff_event, separators=(",", ":")) + "\n" + ) if not os.path.exists(output_dir): os.makedirs(output_dir) - for output_file, parsed_lines in outputs.items(): + + for output_file, final_lines in all_output_lines.items(): with open(output_file, "w") as out: - out.writelines(parsed_lines) + out.writelines(final_lines) def parse_args(): @@ -638,6 +682,178 @@ def parse_args(): return parser.parse_args() +# Fields that are expected to vary but are not useful to list out in the diff. +SUMMARY_FIELDS = ["pid", "timestamp", "stream", "function", "data_ptr"] + + +def _flatten_dict( + d: Dict[str, Any], parent_key: str = "", sep: str = "." +) -> Dict[str, Any]: + """ + Flattens a nested dictionary. + """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def _unflatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: + """ + Unflattens a dictionary with delimited keys. + """ + result = {} + for key, value in d.items(): + parts = key.split(sep) + d_ref = result + for part in parts[:-1]: + if part not in d_ref: + d_ref[part] = {} + d_ref = d_ref[part] + d_ref[parts[-1]] = value + return result + + +def _to_ranges(indices: List[int]) -> List[Dict[str, int]]: + """ + Converts a sorted list of indices into a list of continuous ranges. + e.g., [0, 1, 2, 5, 6, 8] -> [{'start': 0, 'end': 2}, {'start': 5, 'end': 6}, {'start': 8, 'end': 8}] + """ + if not indices: + return [] + + indices = sorted(indices) + ranges = [] + start = indices[0] + end = indices[0] + + for i in range(1, len(indices)): + if indices[i] == end + 1: + end = indices[i] + else: + ranges.append({"start": start, "end": end}) + start = end = indices[i] + + ranges.append({"start": start, "end": end}) + return ranges + + +def _generate_launch_diff( + launches: List[Tuple[Dict[str, Any], int]], +) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, int]]]: + """ + Compares a list of launch events and returns sames, diffs, and an index map. 
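+
+    For illustration (hypothetical values): with two launches that differ only
+    in grid, sames holds every field shared by both, diffs looks like
+    {"grid": {"diff_type": "distribution", "values": [...]}} where each entry
+    records a value, its count, and the launch ranges it appeared in, and the
+    returned index map is a list of inclusive ranges over the 1-based trace
+    line numbers, e.g. [{"start": 1, "end": 2}].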
+ """ + if not launches: + return {}, {}, [] + + launch_events = [launch[0] for launch in launches] + launch_index_map = [launch[1] for launch in launches] + + if len(launch_events) == 1: + return ( + _unflatten_dict(_flatten_dict(launch_events[0])), + {}, + _to_ranges(launch_index_map), + ) + + # Group values by key + data_by_key = defaultdict(lambda: defaultdict(list)) + for i, launch in enumerate(launch_events): + launch_flat = _flatten_dict(launch) + for key, value in launch_flat.items(): + # JSON doesn't support all Python types as values directly, str is safer + value_str = json.dumps(value, sort_keys=True) + data_by_key[key][value_str].append(i) + + sames_flat = {} + diffs_flat = {} + + for key, value_groups in data_by_key.items(): + if len(value_groups) == 1: + # This key has the same value across all launches + value_str = list(value_groups.keys())[0] + sames_flat[key] = json.loads(value_str) + else: + # This key has different values + is_summary = any(summary_key in key for summary_key in SUMMARY_FIELDS) + if is_summary: + diffs_flat[key] = { + "diff_type": "summary", + "summary_text": f"Varies across {len(value_groups)} unique values", + } + else: + values_dist = [] + for value_str, indices in value_groups.items(): + values_dist.append( + { + "value": json.loads(value_str), + "count": len(indices), + "launches": _to_ranges(indices), + } + ) + # Sort by first occurrence + values_dist.sort(key=lambda x: x["launches"][0]["start"]) + diffs_flat[key] = { + "diff_type": "distribution", + "values": values_dist, + } + + # Unflatten the results + sames_unflattened = _unflatten_dict(sames_flat) + diffs_unflattened = _unflatten_dict(diffs_flat) + + # Special handling for extracted_args to create argument_diff structures + if "extracted_args" in sames_unflattened or "extracted_args" in diffs_unflattened: + sames_args = sames_unflattened.pop("extracted_args", {}) + diffs_args_flat = diffs_unflattened.pop("extracted_args", {}) + + all_arg_names = set(sames_args.keys()) | set(diffs_args_flat.keys()) + + final_arg_diffs = {} + + for arg_name in all_arg_names: + if arg_name in diffs_args_flat: + # This argument has at least one differing sub-field. + arg_sames = {} + arg_diffs_internal = {} + + # Collect all sub-fields for this argument from the original data + all_sub_fields = set() + for launch in launch_events: + arg_data = launch.get("extracted_args", {}).get(arg_name, {}) + all_sub_fields.update(arg_data.keys()) + + for sub_field in all_sub_fields: + flat_key = f"extracted_args.{arg_name}.{sub_field}" + if flat_key in diffs_flat: + arg_diffs_internal[sub_field] = diffs_flat[flat_key] + elif flat_key in sames_flat: + arg_sames[sub_field] = sames_flat[flat_key] + + if arg_sames or arg_diffs_internal: + final_arg_diffs[arg_name] = { + "diff_type": "argument_diff", + "sames": arg_sames, + "diffs": arg_diffs_internal, + } + elif arg_name in sames_args: + # This argument is entirely the same across all launches. + # We move it back to the main sames dict for consistency. 
+ if "extracted_args" not in sames_unflattened: + sames_unflattened["extracted_args"] = {} + sames_unflattened["extracted_args"][arg_name] = sames_args[arg_name] + + if final_arg_diffs: + diffs_unflattened["extracted_args"] = final_arg_diffs + + return sames_unflattened, diffs_unflattened, _to_ranges(launch_index_map) + + if __name__ == "__main__": args = parse_args() if args.input: diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py index 76c6b6d..81711f9 100644 --- a/tritonparse/structured_logging.py +++ b/tritonparse/structured_logging.py @@ -156,6 +156,9 @@ def convert(obj): # 2. simple containers ---------------------------------------------------- if isinstance(obj, (list, tuple)): + # Handle namedtuple specially to preserve field names + if hasattr(obj, "_asdict"): + return convert(obj._asdict()) return [convert(x) for x in obj] if isinstance(obj, (set, frozenset)): @@ -810,7 +813,7 @@ def extract_arg_info(arg_dict): def add_launch_metadata(grid, metadata, arg_dict): # Extract detailed argument information extracted_args = extract_arg_info(arg_dict) - return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)} + return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)} class JITHookImpl(JITHook): @@ -928,7 +931,7 @@ def __call__(self, metadata): ) if launch_metadata_tritonparse is not None: trace_data["grid"] = launch_metadata_tritonparse[0] - trace_data["metadata"] = launch_metadata_tritonparse[1] + trace_data["compilation_metadata"] = launch_metadata_tritonparse[1] trace_data["extracted_args"] = launch_metadata_tritonparse[ 2 ] # Now contains detailed arg info From 76ae5328c7b0cb0c96b0c589a957eeeb0155065f Mon Sep 17 00:00:00 2001 From: FindHao Date: Sun, 20 Jul 2025 17:22:35 -0400 Subject: [PATCH 2/4] Add ArgumentViewer, DiffViewer, and StackDiffViewer components Summary: - Introduced `ArgumentViewer` to display argument distributions and differences in a structured format. - Added `DiffViewer` to handle and present differing fields, integrating `ArgumentViewer` for extracted arguments. - Created `StackDiffViewer` to visualize stack traces with collapsible sections for better user experience. - Updated `KernelOverview` to incorporate the new components for enhanced launch analysis and metadata display. - Enhanced `dataLoader` to support launch_diff events, improving kernel data processing. These changes aim to improve the visualization and analysis of kernel launch data, enhancing the overall functionality of the application. 
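
For reference, a rough sketch of the launch_diff event shape these components consume,
reconstructed from the Python changes in patch 1. This is an illustrative type only, not
a checked-in definition; any detail not visible in that diff is an assumption.

```typescript
// Illustrative only: approximate shape of a launch_diff NDJSON event as read by
// DiffViewer/ArgumentViewer; mirrors the event built in tritonparse/extract_source_mappings.py.
interface LaunchRange { start: number; end: number }   // inclusive, 1-based trace line numbers
interface SummaryDiff { diff_type: "summary"; summary_text: string }
interface DistributionDiff {
  diff_type: "distribution";
  values: { value: unknown; count: number; launches: LaunchRange[] }[];
}
interface ArgumentDiff {
  diff_type: "argument_diff";
  sames: Record<string, unknown>;                         // sub-fields identical across launches
  diffs: Record<string, SummaryDiff | DistributionDiff>;  // sub-fields that vary
}
interface LaunchDiffEvent {
  event_type: "launch_diff";
  hash: string;                       // kernel hash shared with the matching compilation event
  name: string | null;
  total_launches: number;
  launch_index_map: LaunchRange[];    // where each launch sits in the original trace
  sames: Record<string, unknown>;     // fields identical across all launches
  // Per-field diffs; diffs.extracted_args maps argument name -> ArgumentDiff.
  diffs: Record<string, SummaryDiff | DistributionDiff | Record<string, ArgumentDiff>>;
}
```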
--- website/src/components/ArgumentViewer.tsx | 133 +++++++++++++ website/src/components/DiffViewer.tsx | 96 ++++++++++ website/src/components/StackDiffViewer.tsx | 81 ++++++++ website/src/pages/KernelOverview.tsx | 207 +++++++++++++++------ website/src/utils/dataLoader.ts | 71 ++++--- 5 files changed, 510 insertions(+), 78 deletions(-) create mode 100644 website/src/components/ArgumentViewer.tsx create mode 100644 website/src/components/DiffViewer.tsx create mode 100644 website/src/components/StackDiffViewer.tsx diff --git a/website/src/components/ArgumentViewer.tsx b/website/src/components/ArgumentViewer.tsx new file mode 100644 index 0000000..031e0a9 --- /dev/null +++ b/website/src/components/ArgumentViewer.tsx @@ -0,0 +1,133 @@ +import React, { useState } from 'react'; + +// Renders the value distribution (e.g., "16 (2 times, in launches: 1-2)") +const DistributionCell: React.FC<{ data: any }> = ({ data }) => { + if (!data) return null; + if (data.diff_type === 'summary') { + return {data.summary_text}; + } + if (data.diff_type === 'distribution' && data.values) { + return ( + + ); + } + return {JSON.stringify(data)}; +}; + +// Renders a single row in the ArgumentViewer table +const ArgumentRow: React.FC<{ + argName: string; + argData: any; + isDiffViewer?: boolean; +}> = ({ argName, argData, isDiffViewer = false }) => { + // Case 1: This is a complex argument with internal differences + if (isDiffViewer && argData.diff_type === "argument_diff") { + const [isCollapsed, setIsCollapsed] = useState(false); + const { sames, diffs } = argData; + const hasSames = Object.keys(sames).length > 0; + const hasDiffs = Object.keys(diffs).length > 0; + + return ( +
+
setIsCollapsed(!isCollapsed)} + > +
{argName}
+
+ Complex argument with internal differences + +
+
+ {!isCollapsed && ( +
+ {hasSames && ( +
+
Unchanged Properties
+
+ {Object.entries(sames).map(([key, value]) => ( +
+ {key}: + {JSON.stringify(value)} +
+ ))} +
+
+ )} + {hasDiffs && ( +
+
Differing Properties
+
+ {Object.entries(diffs).map(([key, value]) => ( +
+ {key}: +
+
+ ))} +
+
+ )} +
+ )} +
+ ); + } + + // Case 2: This is a simple argument (in the "Sames" table) + return ( +
+
{argName}
+
{argData.type}
+
+ {typeof argData.value !== 'object' || argData.value === null ? + {String(argData.value)} : +
{JSON.stringify(argData, null, 2)}
+ } +
+
+ ); +}; + +// Main container component +const ArgumentViewer: React.FC<{ args: Record; isDiffViewer?: boolean; }> = ({ args, isDiffViewer = false }) => { + if (!args || Object.keys(args).length === 0) { + return
No arguments to display.
; + } + + // A "complex view" is needed if we are showing diffs and at least one of them is a complex argument_diff + const isComplexView = isDiffViewer && Object.values(args).some(arg => arg.diff_type === 'argument_diff'); + + return ( +
+ {/* Render header only for the simple, non-complex table view */} + {!isComplexView && ( +
+
Argument Name
+
Type
+
Value
+
+ )} + + {/* Rows */} +
+ {Object.entries(args).map(([argName, argData]) => ( + + ))} +
+
+ ); +}; + +export default ArgumentViewer; diff --git a/website/src/components/DiffViewer.tsx b/website/src/components/DiffViewer.tsx new file mode 100644 index 0000000..bf6a978 --- /dev/null +++ b/website/src/components/DiffViewer.tsx @@ -0,0 +1,96 @@ +import ArgumentViewer from "./ArgumentViewer"; +import React from "react"; +import StackDiffViewer from "./StackDiffViewer"; + +interface DiffViewerProps { + diffs: any; +} + +const DiffViewer: React.FC = ({ diffs }) => { + if (!diffs || Object.keys(diffs).length === 0) { + return ( +

No differing fields detected.

+ ); + } + + // Separate different kinds of diffs + const extractedArgs = diffs.extracted_args; + const stackDiff = diffs.stack; + const otherDiffs = Object.fromEntries( + Object.entries(diffs).filter( + ([key]) => key !== "extracted_args" && key !== "stack" + ) + ); + + const renderSimpleDiff = (key: string, data: any) => { + if (data.diff_type === "summary") { + return

{data.summary_text}

; + } + if (data.diff_type === "distribution") { + return ( +
    + {data.values.map((item: any, index: number) => { + const launchRanges = item.launches + .map((r: any) => + r.start === r.end + ? `${r.start + 1}` + : `${r.start + 1}-${r.end + 1}` + ) + .join(", "); + return ( +
  • + + {JSON.stringify(item.value)} + + + ({item.count} times, in launches: {launchRanges}) + +
  • + ); + })} +
+ ); + } + // Fallback for unexpected structures + return
{JSON.stringify(data, null, 2)}
; + }; + + return ( +
+ {extractedArgs && Object.keys(extractedArgs).length > 0 && ( +
+
+ Extracted Arguments +
+ +
+ )} + + {Object.keys(otherDiffs).length > 0 && ( +
+
+ Other Differing Fields +
+
+ {Object.entries(otherDiffs).map(([key, value]) => ( +
+ + {key} + +
{renderSimpleDiff(key, value)}
+
+ ))} +
+
+ )} + + {stackDiff && } + +
+ ); +}; + +export default DiffViewer; \ No newline at end of file diff --git a/website/src/components/StackDiffViewer.tsx b/website/src/components/StackDiffViewer.tsx new file mode 100644 index 0000000..d4269fd --- /dev/null +++ b/website/src/components/StackDiffViewer.tsx @@ -0,0 +1,81 @@ +import React, { useState } from 'react'; +import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; +import { oneLight } from 'react-syntax-highlighter/dist/esm/styles/prism'; + +// A single frame of a stack trace +const StackTraceFrame: React.FC<{ frame: any }> = ({ frame }) => ( +
+ {frame.filename}: + {frame.line} in{" "} + {frame.name} + {frame.line_code && ( +
+ + {frame.line_code} + +
+ )} +
+); + + +const StackDiffViewer: React.FC<{ stackDiff: any }> = ({ stackDiff }) => { + const [isCollapsed, setIsCollapsed] = useState(true); + + if (!stackDiff || stackDiff.diff_type !== 'distribution') { + return null; + } + + return ( +
+
setIsCollapsed(!isCollapsed)} + > + Stack Traces + + + +
+ {!isCollapsed && ( +
+ {stackDiff.values.map((item: any, index: number) => { + const launchRanges = item.launches + .map((r: any) => (r.start === r.end ? `${r.start + 1}` : `${r.start + 1}-${r.end + 1}`)) + .join(", "); + + return ( +
+

+ Variant seen {item.count} times (in launches: {launchRanges}) +

+
+ {Array.isArray(item.value) ? item.value.map((frame: any, frameIndex: number) => ( + + )) :

Invalid stack format

} +
+
+ ); + })} +
+ )} +
+ ); +}; + +export default StackDiffViewer; \ No newline at end of file diff --git a/website/src/pages/KernelOverview.tsx b/website/src/pages/KernelOverview.tsx index b78ed05..3708bdc 100644 --- a/website/src/pages/KernelOverview.tsx +++ b/website/src/pages/KernelOverview.tsx @@ -1,4 +1,6 @@ import React from "react"; +import ArgumentViewer from "../components/ArgumentViewer"; +import DiffViewer from "../components/DiffViewer"; import { ProcessedKernel } from "../utils/dataLoader"; interface KernelOverviewProps { @@ -8,6 +10,14 @@ interface KernelOverviewProps { onViewIR: (irType: string) => void; } +/** + * Determines if a metadata value is considered "long" and should be displayed at the end + */ +const isLongValue = (value: any): boolean => { + const formattedString = formatMetadataValue(value); + return formattedString.length > 50; +}; + /** * Formats a value for display in the metadata section * @param value - The value to format @@ -41,15 +51,15 @@ interface MetadataItemProps { const MetadataItem: React.FC = ({ label, value, - span = 1 + span = 1, }) => ( -
1 ? `col-span-${span}` : ''}`}> - - {label} - - - {value} - +
1 ? `col-span-${span}` : ""} ${ + span === 0 ? "col-span-full" : "" + }`} + > + {label} + {value}
); @@ -121,69 +131,152 @@ const KernelOverview: React.FC = ({ Compilation Metadata
-
- {/* Hash */} - {kernel.metadata.hash && ( - - )} - - {/* Target Info */} - {kernel.metadata.target && ( - <> - - - - - )} + {/* Short fields in responsive grid */} +
+ {/* All short metadata fields */} + {Object.entries(kernel.metadata) + .filter(([_key, value]) => !isLongValue(value)) + .map(([key, value]) => { + return ( + + word.charAt(0).toUpperCase() + word.slice(1) + ) + .join(" ")} + value={formatMetadataValue(value)} + /> + ); + })} +
- {/* Execution Configuration */} - - - + {/* Long fields in separate section within same container */} + {Object.entries(kernel.metadata).filter(([_key, value]) => + isLongValue(value) + ).length > 0 && ( +
+ {Object.entries(kernel.metadata) + .filter(([_key, value]) => isLongValue(value)) + .map(([key, value]) => ( +
+ + {key + .split("_") + .map( + (word) => + word.charAt(0).toUpperCase() + word.slice(1) + ) + .join(" ")} + + + {formatMetadataValue(value)} + +
+ ))} +
+ )} +
+
+ )} - {/* Cluster Dimensions */} - {kernel.metadata.cluster_dims && ( - - )} + {/* Launch Analysis Section */} + {kernel.launchDiff && ( +
+

+ Launch Analysis +

+
+

+ Total Launches:{" "} + {kernel.launchDiff.total_launches} +

- {/* Other Metadata */} - - + {/* Launch Index Map */} + {kernel.launchDiff.launch_index_map && ( +
+

+ Launch Locations in Original Trace{" "} + + (1-based line numbers) + +

+
+ {kernel.launchDiff.launch_index_map + .map((r: any) => + r.start === r.end + ? `${r.start}` + : `${r.start}-${r.end}` + ) + .join(", ")} +
+
+ )} - {/* Supported FP8 Types */} - {kernel.metadata.supported_fp8_dtypes && - kernel.metadata.supported_fp8_dtypes.length > 0 && ( - - )} + {/* Unchanged Fields */} + {kernel.launchDiff.sames && Object.keys(kernel.launchDiff.sames).length > 0 && ( +
+

+ Unchanged Launch Arguments +

+ +
+ )} - {/* Additional metadata fields */} - {Object.entries(kernel.metadata) - .filter( + {(() => { + const otherSames = Object.fromEntries( + Object.entries(kernel.launchDiff.sames).filter( ([key]) => - ![ - "hash", - "target", - "num_warps", - "num_ctas", - "num_stages", - "cluster_dims", - "enable_fp_fusion", - "launch_cooperative_grid", - "supported_fp8_dtypes", - ].includes(key) + key !== "compilation_metadata" && + key !== "extracted_args" && + key !== "event_type" ) - .map(([key, value]) => ( - word.charAt(0).toUpperCase() + word.slice(1)).join(" ")} value={formatMetadataValue(value)} /> - ))} + ); + + if (Object.keys(otherSames).length > 0) { + return ( +
+

+ Other Unchanged Fields +

+
+ {Object.entries(otherSames).map(([key, value]) => ( + + word.charAt(0).toUpperCase() + word.slice(1) + ) + .join(" ")} + value={formatMetadataValue(value)} + /> + ))} +
+
+ ); + } + return null; + })()} + + {/* Differing Fields */} +
+

+ Differing Fields +

+
)} - + {/* Stack Trace */}

- Stack Trace + Compilation Stack Trace

{kernel.stack.map((entry, index) => ( diff --git a/website/src/utils/dataLoader.ts b/website/src/utils/dataLoader.ts index ed8e84c..84f81d0 100644 --- a/website/src/utils/dataLoader.ts +++ b/website/src/utils/dataLoader.ts @@ -99,6 +99,12 @@ export interface LogEntry { source_mappings?: Record>; // Alternative field name for source_mapping python_source?: PythonSourceCodeInfo; }; + // Fields for launch_diff event type + hash?: string; + name?: string; + total_launches?: number; + diffs?: any; + sames?: any; } /** @@ -113,6 +119,7 @@ export interface ProcessedKernel { sourceMappings?: Record>; // Source mappings for each IR file pythonSourceInfo?: PythonSourceCodeInfo; // Python source code information metadata?: KernelMetadata; // Compilation metadata + launchDiff?: any; // Aggregated launch event differences } /** @@ -121,6 +128,7 @@ export interface ProcessedKernel { * @returns Array of LogEntry objects */ export function parseLogData(textData: string): LogEntry[] { + console.log("Starting to parse NDJSON data..."); if (typeof textData !== 'string') { throw new Error("Input must be a string in NDJSON format"); } @@ -142,9 +150,11 @@ export function parseLogData(textData: string): LogEntry[] { } if (entries.length === 0) { + console.error("No valid JSON entries found in NDJSON data"); throw new Error("No valid JSON entries found in NDJSON data"); } + console.log(`Successfully parsed ${entries.length} log entries.`); return entries; } catch (error) { console.error("Error parsing NDJSON data:", error); @@ -274,12 +284,19 @@ export function loadLogDataFromFile(file: File): Promise { * @returns Array of processed kernel objects ready for display */ export function processKernelData(logEntries: LogEntry[]): ProcessedKernel[] { - const kernels: ProcessedKernel[] = []; - for (let i = 0; i < logEntries.length; i++) { - const entry = logEntries[i]; - // Check for kernel events by event_type + console.log("Processing kernel data... 
Total entries:", logEntries.length); + const kernelsByHash: Map = new Map(); + + // First pass: process all compilation events + for (const entry of logEntries) { if (entry.event_type === "compilation" && entry.payload) { - // Ensure payload has file_path and file_content + const hash = entry.payload.metadata?.hash; + if (!hash) { + console.warn("Compilation event missing hash", entry); + continue; + } + console.log(`Processing compilation event for hash: ${hash}`) + if (!entry.payload.file_path || !entry.payload.file_content) { console.warn( "Kernel event missing file_path or file_content", @@ -287,10 +304,9 @@ export function processKernelData(logEntries: LogEntry[]): ProcessedKernel[] { ); continue; } - // Extract kernel name from IR filename + const irFileNames = Object.keys(entry.payload.file_path); let kernelName = "unknown_kernel"; - // Use first IR file name to determine kernel name if (irFileNames.length > 0) { const fileName = irFileNames[0]; const nameParts = fileName.split("."); @@ -300,19 +316,9 @@ export function processKernelData(logEntries: LogEntry[]): ProcessedKernel[] { : fileName; } - // Extract source mapping information from payload if available - let sourceMappings: Record< - string, - Record - > = {}; + const sourceMappings = entry.payload.source_mappings || {}; - if (entry.payload.source_mappings) { - // Use source mappings from the trace file - sourceMappings = entry.payload.source_mappings; - } - - // Create processed kernel object and add to results - kernels.push({ + const newKernel: ProcessedKernel = { name: kernelName, sourceFiles: entry.stack?.map(entry => typeof entry.filename === 'string' ? entry.filename : @@ -324,9 +330,32 @@ export function processKernelData(logEntries: LogEntry[]): ProcessedKernel[] { sourceMappings, pythonSourceInfo: entry.payload.python_source, metadata: entry.payload.metadata, - }); + }; + kernelsByHash.set(hash, newKernel); + console.log(`Stored kernel ${kernelName} with hash ${hash} into map.`); + } + } + console.log(`Finished first pass. Total kernels processed: ${kernelsByHash.size}`); + + // Second pass: attach launch_diff events + console.log("Starting second pass to attach launch_diff events..."); + for (const entry of logEntries) { + if (entry.event_type === "launch_diff") { // No payload for launch_diff + console.log("Found a launch_diff event:", entry); + const hash = entry.hash; + console.log(`launch_diff event hash: ${hash}`); + if (hash && kernelsByHash.has(hash)) { + const kernel = kernelsByHash.get(hash)!; + console.log(`Found matching kernel for hash ${hash}. Attaching launch_diff.`); + kernel.launchDiff = entry; // Attach the entire event object + } else { + console.warn(`Could not find matching kernel for launch_diff hash: ${hash}`); + } } } - return kernels; + + const finalKernels = Array.from(kernelsByHash.values()); + console.log("Finished processing. Final kernel objects:", finalKernels); + return finalKernels; } From b82eee81b7915f34e6c6089a133d7c381eb6cae1 Mon Sep 17 00:00:00 2001 From: FindHao Date: Sun, 20 Jul 2025 17:30:57 -0400 Subject: [PATCH 3/4] Enhance ArgumentViewer, DiffViewer, and StackDiffViewer components Summary: - Added comments for clarity, including a note for the dropdown arrow icon in both `ArgumentViewer` and `StackDiffViewer`. - Updated the `renderSimpleDiff` function in `DiffViewer` to use a more descriptive parameter name. These changes improve code readability and maintainability across the components. 
--- website/src/components/ArgumentViewer.tsx | 1 + website/src/components/DiffViewer.tsx | 2 +- website/src/components/StackDiffViewer.tsx | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/website/src/components/ArgumentViewer.tsx b/website/src/components/ArgumentViewer.tsx index 031e0a9..9925bb4 100644 --- a/website/src/components/ArgumentViewer.tsx +++ b/website/src/components/ArgumentViewer.tsx @@ -48,6 +48,7 @@ const ArgumentRow: React.FC<{
{argName}
Complex argument with internal differences + {/* Dropdown arrow icon */}
diff --git a/website/src/components/DiffViewer.tsx b/website/src/components/DiffViewer.tsx index bf6a978..481598b 100644 --- a/website/src/components/DiffViewer.tsx +++ b/website/src/components/DiffViewer.tsx @@ -22,7 +22,7 @@ const DiffViewer: React.FC = ({ diffs }) => { ) ); - const renderSimpleDiff = (key: string, data: any) => { + const renderSimpleDiff = (_key: string, data: any) => { if (data.diff_type === "summary") { return

{data.summary_text}

; } diff --git a/website/src/components/StackDiffViewer.tsx b/website/src/components/StackDiffViewer.tsx index d4269fd..06b06ee 100644 --- a/website/src/components/StackDiffViewer.tsx +++ b/website/src/components/StackDiffViewer.tsx @@ -42,6 +42,7 @@ const StackDiffViewer: React.FC<{ stackDiff: any }> = ({ stackDiff }) => { onClick={() => setIsCollapsed(!isCollapsed)} > Stack Traces + {/* Dropdown arrow icon */} Date: Sun, 20 Jul 2025 17:39:35 -0400 Subject: [PATCH 4/4] Add complex kernel tests for Triton Summary: - Introduced a new test file `test_complex_kernels.py` to validate the functionality of two distinct Triton kernels: an autotuned matrix multiplication kernel and a fused element-wise operation kernel. - Implemented a comprehensive test plan that includes multiple launches with varied parameters for both kernels, ensuring thorough testing of their performance and correctness. - Enhanced logging capabilities during the tests to facilitate debugging and analysis of kernel launches. These changes aim to improve the testing framework for Triton kernels, ensuring robust validation of their functionality and performance. --- tests/test_complex_kernels.py | 256 ++++++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 tests/test_complex_kernels.py diff --git a/tests/test_complex_kernels.py b/tests/test_complex_kernels.py new file mode 100644 index 0000000..389d8a4 --- /dev/null +++ b/tests/test_complex_kernels.py @@ -0,0 +1,256 @@ +""" +A more complex test case involving two distinct Triton kernels, one of which uses autotuning. +This test is designed to validate the launch_diff functionality with multiple, varied launches. + +Test Plan: +``` +TORCHINDUCTOR_FX_GRAPH_CACHE=0 TRITONPARSE_DEBUG=1 python tests/test_complex_kernels.py +``` +""" + +import os + +import torch +import triton +import triton.language as tl + +import tritonparse.structured_logging +import tritonparse.utils + +# Initialize logging +log_path = "./logs_complex" +tritonparse.structured_logging.init(log_path, enable_trace_launch=True) + +os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0" + + +# Kernel 1: Autotuned Matmul (simplified configs for small scale) +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + }, + num_stages=1, + num_warps=1, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + }, + num_stages=1, + num_warps=1, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + }, + num_stages=1, + num_warps=1, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a, + b, + c, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size) + pid_n = (pid % num_pid_in_group) // group_size + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + 
a_ptrs = a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_block = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + b_block = tl.load( + b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator += tl.dot(a_block, b_block) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c_block = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c_block, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ACTIVATION=None, + ) + return c + + +# Kernel 2: Fused element-wise operation +@triton.jit +def fused_op_kernel( + a_ptr, + b_ptr, + c_ptr, + output_ptr, + n_elements, + scale_factor: float, + ACTIVATION: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + c = tl.load(c_ptr + offsets, mask=mask) + + result = a * b * scale_factor + c + if ACTIVATION == "relu": + result = tl.where(result > 0, result, 0.0) + + tl.store(output_ptr + offsets, result, mask=mask) + + +def fused_op(a, b, c, scale_factor: float, activation: str): + n_elements = a.numel() + output = torch.empty_like(a) + BLOCK_SIZE = 8 # Reduced from 1024 for small scale testing + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + fused_op_kernel[grid]( + a, + b, + c, + output, + n_elements, + scale_factor, + ACTIVATION=activation, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output + + +def test_complex_kernels(): + """Main test function to run both kernels with varied parameters.""" + torch.manual_seed(0) + + # --- Matmul Launches (3 times with different configs) --- + print("--- Testing Matmul Kernel (3 launches) ---") + # Launch 1 + a1 = torch.randn((16, 16), device="cuda", dtype=torch.float16) + b1 = torch.randn((16, 16), device="cuda", dtype=torch.float16) + c1 = matmul(a1, b1) + c1.sum() # Synchronize + print("Matmul Launch 1 (16x16 @ 16x16) done.") + + # Launch 2 + a2 = torch.randn((32, 16), device="cuda", dtype=torch.float16) + b2 = torch.randn((16, 32), device="cuda", dtype=torch.float16) + c2 = matmul(a2, b2) + c2.sum() # Synchronize + print("Matmul Launch 2 (32x16 @ 16x32) done.") + + # Launch 3 + a3 = torch.randn((16, 32), device="cuda", dtype=torch.float16) + b3 = torch.randn((32, 16), device="cuda", dtype=torch.float16) + c3 = matmul(a3, b3) + c3.sum() # Synchronize + print("Matmul Launch 3 (16x32 @ 32x16) done.") + + # --- Fused Op Launches (4 times with different parameters) --- + print("\n--- Testing Fused Op Kernel (4 launches) ---") + x = torch.randn((8,), device="cuda", dtype=torch.float32) + y = 
torch.randn((8,), device="cuda", dtype=torch.float32) + z = torch.randn((8,), device="cuda", dtype=torch.float32) + + # Launch 1 + print("Fused Op Launch 1: scale=1.0, activation=None") + out1 = fused_op(x, y, z, scale_factor=1.0, activation="none") + out1.sum() # Synchronize + + # Launch 2 + print("Fused Op Launch 2: scale=2.5, activation=None") + out2 = fused_op(x, y, z, scale_factor=2.5, activation="none") + out2.sum() # Synchronize + + # Launch 3 + print("Fused Op Launch 3: scale=1.0, activation='relu'") + out3 = fused_op(x, y, z, scale_factor=1.0, activation="relu") + out3.sum() # Synchronize + + # Launch 4 (different size) + print("Fused Op Launch 4: scale=1.0, activation='relu', different size") + x_large = torch.randn((6,), device="cuda", dtype=torch.float32) + y_large = torch.randn((6,), device="cuda", dtype=torch.float32) + z_large = torch.randn((6,), device="cuda", dtype=torch.float32) + out4 = fused_op(x_large, y_large, z_large, scale_factor=1.0, activation="relu") + out4.sum() # Synchronize + print("All kernels executed.") + + +if __name__ == "__main__": + test_complex_kernels() + # Use unified_parse to process the generated logs + tritonparse.utils.unified_parse( + source=log_path, out="./parsed_output_complex", overwrite=True + )