diff --git a/tests/test_add.py b/tests/test_add.py
index 9bf1168..94ca07f 100644
--- a/tests/test_add.py
+++ b/tests/test_add.py
@@ -7,6 +7,8 @@
 ```
 """
 
+import os
+
 import torch
 import triton
 import triton.language as tl
@@ -15,7 +17,9 @@ import tritonparse.utils
 
 log_path = "./logs"
 
-tritonparse.structured_logging.init(log_path)
+tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
+
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
 
 
 @triton.jit
diff --git a/tests/test_tritonparse.py b/tests/test_tritonparse.py
index d36b620..d01a060 100644
--- a/tests/test_tritonparse.py
+++ b/tests/test_tritonparse.py
@@ -6,6 +6,7 @@
 ```
 """
 
+import json
 import os
 import shutil
 import tempfile
@@ -170,12 +171,7 @@ def test_whole_workflow(self):
 
         # Define a simple kernel directly in the test function
         @triton.jit
-        def test_kernel(
-            x_ptr,
-            y_ptr,
-            n_elements,
-            BLOCK_SIZE: tl.constexpr,
-        ):
+        def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
             pid = tl.program_id(axis=0)
             block_start = pid * BLOCK_SIZE
             offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -189,48 +185,115 @@ def test_kernel(
         def run_test_kernel(x):
             n_elements = x.numel()
             y = torch.empty_like(x)
-            BLOCK_SIZE = 256  # Smaller block size for simplicity
+            BLOCK_SIZE = 256
             grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
             test_kernel[grid](x, y, n_elements, BLOCK_SIZE)
             return y
 
+        # Set up test environment
         temp_dir = tempfile.mkdtemp()
-        print(f"Temporary directory: {temp_dir}")
         temp_dir_logs = os.path.join(temp_dir, "logs")
-        os.makedirs(temp_dir_logs, exist_ok=True)
         temp_dir_parsed = os.path.join(temp_dir, "parsed_output")
+        os.makedirs(temp_dir_logs, exist_ok=True)
         os.makedirs(temp_dir_parsed, exist_ok=True)
+        print(f"Temporary directory: {temp_dir}")
 
-        tritonparse.structured_logging.init(temp_dir_logs)
+        # Initialize logging
+        tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True)
 
-        # Generate some triton compilation activity to create log files
+        # Generate test data and run kernels
        torch.manual_seed(0)
         size = (512, 512)  # Smaller size for faster testing
         x = torch.randn(size, device=self.cuda_device, dtype=torch.float32)
-        run_test_kernel(x)  # Run the simple kernel
+
+        # Run kernel twice to generate compilation and launch events
+        run_test_kernel(x)
+        run_test_kernel(x)
         torch.cuda.synchronize()
 
-        # Check that temp_dir_logs folder has content
+        # Verify log directory
         assert os.path.exists(
             temp_dir_logs
         ), f"Log directory {temp_dir_logs} does not exist."
         log_files = os.listdir(temp_dir_logs)
-        assert (
-            len(log_files) > 0
-        ), f"No log files found in {temp_dir_logs}. Expected log files to be generated during Triton compilation."
+        assert len(log_files) > 0, (
+            f"No log files found in {temp_dir_logs}. "
+            "Expected log files to be generated during Triton compilation."
+        )
         print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}")
 
+        def parse_log_line(line: str, line_num: int) -> dict | None:
+            """Parse a single log line and extract event data"""
+            try:
+                return json.loads(line.strip())
+            except json.JSONDecodeError as e:
+                print(f"  Line {line_num}: JSON decode error - {e}")
+                return None
+
+        def process_event_data(
+            event_data: dict, line_num: int, event_counts: dict
+        ) -> None:
+            """Process event data and update counts"""
+            try:
+                event_type = event_data.get("event_type")
+                if event_type is None:
+                    return
+
+                if event_type in event_counts:
+                    event_counts[event_type] += 1
+                    print(
+                        f"  Line {line_num}: event_type = '{event_type}' (count: {event_counts[event_type]})"
+                    )
+                else:
+                    print(
+                        f"  Line {line_num}: event_type = '{event_type}' (not tracked)"
+                    )
+            except (KeyError, TypeError) as e:
+                print(f"  Line {line_num}: Data structure error - {e}")
+
+        def count_events_in_file(file_path: str, event_counts: dict) -> None:
+            """Count events in a single log file"""
+            print(f"Checking event types in: {os.path.basename(file_path)}")
+
+            with open(file_path, "r") as f:
+                for line_num, line in enumerate(f, 1):
+                    event_data = parse_log_line(line, line_num)
+                    if event_data:
+                        process_event_data(event_data, line_num, event_counts)
+
+        def check_event_type_counts_in_logs(log_dir: str) -> dict:
+            """Count 'launch' and 'compilation' events in all log files"""
+            event_counts = {"compilation": 0, "launch": 0}
+
+            for log_file in os.listdir(log_dir):
+                if log_file.endswith(".ndjson"):
+                    log_file_path = os.path.join(log_dir, log_file)
+                    count_events_in_file(log_file_path, event_counts)
+
+            print(f"Event type counts: {event_counts}")
+            return event_counts
+
+        # Verify event counts
+        event_counts = check_event_type_counts_in_logs(temp_dir_logs)
+        assert (
+            event_counts["compilation"] == 1
+        ), f"Expected 1 'compilation' event, found {event_counts['compilation']}"
+        assert (
+            event_counts["launch"] == 2
+        ), f"Expected 2 'launch' events, found {event_counts['launch']}"
+        print("✓ Verified correct event type counts: 1 compilation, 2 launch")
+
+        # Test parsing functionality
         tritonparse.utils.unified_parse(
             source=temp_dir_logs, out=temp_dir_parsed, overwrite=True
         )
 
-        # Clean up temporary directory
-        try:
-            # Check that parsed output directory has files
-            parsed_files = os.listdir(temp_dir_parsed)
-            assert len(parsed_files) > 0, "No files found in parsed output directory"
-        finally:
-            shutil.rmtree(temp_dir)
+        # Verify parsing output
+        parsed_files = os.listdir(temp_dir_parsed)
+        assert len(parsed_files) > 0, "No files found in parsed output directory"
+
+        # Clean up
+        shutil.rmtree(temp_dir)
 
 
 if __name__ == "__main__":
diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py
index 7b38dbd..d4bd35e 100644
--- a/tritonparse/structured_logging.py
+++ b/tritonparse/structured_logging.py
@@ -981,13 +981,16 @@ def init_basic(trace_folder: Optional[str] = None):
     maybe_enable_trace_launch()
 
 
-def init(trace_folder: Optional[str] = None):
+def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
     """
     This function is a wrapper around init_basic() that also setup the compilation listener.
 
     Args:
         trace_folder (Optional[str]): The folder to store the trace files.
     """
+    global TRITON_TRACE_LAUNCH
+    if enable_trace_launch:
+        TRITON_TRACE_LAUNCH = True
     import triton
 
     init_basic(trace_folder)