From f24e8e773daed8dabca5cff4f9c201eb8ff2e038 Mon Sep 17 00:00:00 2001
From: FindHao
Date: Fri, 11 Jul 2025 23:42:41 -0400
Subject: [PATCH 1/4] Enhance structured logging initialization and test
 coverage

Summary:
- Updated `tritonparse.structured_logging.init` to accept an
  `enable_trace_launch` parameter, so launch tracing can be enabled at
  initialization time.
- Modified the test files to use the new parameter and to verify event type
  counts in the generated log files, tracking 'launch' and 'compilation'
  events.
- Added checks in `test_tritonparse.py` that assert the expected counts of
  log events.
---
 tests/test_add.py                 |  6 +++-
 tests/test_tritonparse.py         | 55 ++++++++++++++++++++++++++++++-
 tritonparse/structured_logging.py |  5 ++-
 3 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/tests/test_add.py b/tests/test_add.py
index 9bf1168..94ca07f 100644
--- a/tests/test_add.py
+++ b/tests/test_add.py
@@ -7,6 +7,8 @@
 ```
 """
 
+import os
+
 import torch
 import triton
 import triton.language as tl
@@ -15,7 +17,9 @@ import tritonparse.utils
 
 log_path = "./logs"
-tritonparse.structured_logging.init(log_path)
+tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
+
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
 
 
 @triton.jit
diff --git a/tests/test_tritonparse.py b/tests/test_tritonparse.py
index d36b620..3e51ab7 100644
--- a/tests/test_tritonparse.py
+++ b/tests/test_tritonparse.py
@@ -6,6 +6,7 @@
 ```
 """
 
+import json
 import os
 import shutil
 import tempfile
@@ -201,13 +202,14 @@ def run_test_kernel(x):
         temp_dir_parsed = os.path.join(temp_dir, "parsed_output")
         os.makedirs(temp_dir_parsed, exist_ok=True)
 
-        tritonparse.structured_logging.init(temp_dir_logs)
+        tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True)
 
         # Generate some triton compilation activity to create log files
         torch.manual_seed(0)
         size = (512, 512)  # Smaller size for faster testing
         x = torch.randn(size, device=self.cuda_device, dtype=torch.float32)
         run_test_kernel(x)  # Run the simple kernel
+        run_test_kernel(x)  # Run the simple kernel again
         torch.cuda.synchronize()
 
         # Check that temp_dir_logs folder has content
@@ -220,6 +222,57 @@
         ), f"No log files found in {temp_dir_logs}. Expected log files to be generated during Triton compilation."
print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}") + # Check that log files contain specific counts of 'launch' and 'compilation' event types + def check_event_type_counts_in_logs(log_dir): + """Check that log files contain exactly 1 'compilation' and 1 'launch' event types""" + event_type_counts = {"compilation": 0, "launch": 0} + + for log_file in os.listdir(log_dir): + if log_file.endswith(".ndjson"): + log_file_path = os.path.join(log_dir, log_file) + print(f"Checking event types in: {log_file}") + + with open(log_file_path, "r") as f: + for line_num, line in enumerate(f, 1): + try: + # Parse each line as JSON + event_data = json.loads(line.strip()) + + # Extract event_type if present + if "event_type" in event_data: + event_type = event_data["event_type"] + if event_type in event_type_counts: + event_type_counts[event_type] += 1 + print( + f" Line {line_num}: event_type = '{event_type}' (count: {event_type_counts[event_type]})" + ) + else: + print( + f" Line {line_num}: event_type = '{event_type}' (not tracked)" + ) + + except json.JSONDecodeError as e: + print(f" Line {line_num}: JSON decode error - {e}") + continue + except Exception as e: + print(f" Line {line_num}: Unexpected error - {e}") + continue + + print(f"Event type counts: {event_type_counts}") + return event_type_counts + + # Check event type counts in log files + event_counts = check_event_type_counts_in_logs(temp_dir_logs) + + # Assert specific counts for 'launch' and 'compilation' event types + assert ( + event_counts["compilation"] == 1 + ), f"Expected 1 'compilation' event, found {event_counts['compilation']}" + assert ( + event_counts["launch"] == 2 + ), f"Expected 2 'launch' events, found {event_counts['launch']}" + print("✓ Verified correct event type counts: 1 compilation, 2 launch") + tritonparse.utils.unified_parse( source=temp_dir_logs, out=temp_dir_parsed, overwrite=True ) diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py index 7b38dbd..d4bd35e 100644 --- a/tritonparse/structured_logging.py +++ b/tritonparse/structured_logging.py @@ -981,13 +981,16 @@ def init_basic(trace_folder: Optional[str] = None): maybe_enable_trace_launch() -def init(trace_folder: Optional[str] = None): +def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False): """ This function is a wrapper around init_basic() that also setup the compilation listener. Args: trace_folder (Optional[str]): The folder to store the trace files. """ + global TRITON_TRACE_LAUNCH + if enable_trace_launch: + TRITON_TRACE_LAUNCH = True import triton init_basic(trace_folder) From 72fcf23a64a09fc6da7a0d00bec44cdc7518b47f Mon Sep 17 00:00:00 2001 From: FindHao Date: Mon, 14 Jul 2025 19:46:07 -0400 Subject: [PATCH 2/4] Refactor test_tritonparse.py for improved structure and clarity Summary: - Introduced helper methods to streamline kernel creation, execution, and log verification in the test suite. - Enhanced the test workflow by separating concerns into dedicated methods for setting up directories, generating test data, and verifying log outputs. - Updated the `test_whole_workflow` method to utilize the new helper methods, improving readability and maintainability. - Added assertions to verify event counts and parsing output, ensuring comprehensive testing of the unified_parse functionality. These changes aim to enhance the organization and robustness of the test suite, facilitating easier future modifications and debugging. 
---
 tests/test_tritonparse.py | 213 ++++++++++++++++++++++----------------
 1 file changed, 121 insertions(+), 92 deletions(-)

diff --git a/tests/test_tritonparse.py b/tests/test_tritonparse.py
index 3e51ab7..c6d0b94 100644
--- a/tests/test_tritonparse.py
+++ b/tests/test_tritonparse.py
@@ -165,18 +165,10 @@ def compile_listener(
         assert "python_source" in trace_data
         assert "file_path" in trace_data["python_source"]
 
-    @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
-    def test_whole_workflow(self):
-        """Test unified_parse functionality"""
-
-        # Define a simple kernel directly in the test function
+    def _create_test_kernel(self):
+        """Create a simple test kernel for compilation testing"""
         @triton.jit
-        def test_kernel(
-            x_ptr,
-            y_ptr,
-            n_elements,
-            BLOCK_SIZE: tl.constexpr,
-        ):
+        def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
             pid = tl.program_id(axis=0)
             block_start = pid * BLOCK_SIZE
             offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -185,105 +177,142 @@ def test_kernel(
             x = tl.load(x_ptr + offsets, mask=mask)
             y = x + 1.0  # Simple operation: add 1
             tl.store(y_ptr + offsets, y, mask=mask)
+
+        return test_kernel
 
-        # Simple function to run the kernel
-        def run_test_kernel(x):
-            n_elements = x.numel()
-            y = torch.empty_like(x)
-            BLOCK_SIZE = 256  # Smaller block size for simplicity
-            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
-            test_kernel[grid](x, y, n_elements, BLOCK_SIZE)
-            return y
+    def _run_kernel(self, kernel, x):
+        """Run a kernel with given input tensor"""
+        n_elements = x.numel()
+        y = torch.empty_like(x)
+        BLOCK_SIZE = 256
+        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+        kernel[grid](x, y, n_elements, BLOCK_SIZE)
+        return y
 
+    def _setup_test_directories(self):
+        """Set up temporary directories for testing"""
         temp_dir = tempfile.mkdtemp()
-        print(f"Temporary directory: {temp_dir}")
         temp_dir_logs = os.path.join(temp_dir, "logs")
-        os.makedirs(temp_dir_logs, exist_ok=True)
         temp_dir_parsed = os.path.join(temp_dir, "parsed_output")
+
+        os.makedirs(temp_dir_logs, exist_ok=True)
         os.makedirs(temp_dir_parsed, exist_ok=True)
+
+        return temp_dir, temp_dir_logs, temp_dir_parsed
 
-        tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True)
-
-        # Generate some triton compilation activity to create log files
+    def _generate_test_data(self):
+        """Generate test tensor data"""
         torch.manual_seed(0)
         size = (512, 512)  # Smaller size for faster testing
-        x = torch.randn(size, device=self.cuda_device, dtype=torch.float32)
-        run_test_kernel(x)  # Run the simple kernel
-        run_test_kernel(x)  # Run the simple kernel again
-        torch.cuda.synchronize()
-
-        # Check that temp_dir_logs folder has content
-        assert os.path.exists(
-            temp_dir_logs
-        ), f"Log directory {temp_dir_logs} does not exist."
-        log_files = os.listdir(temp_dir_logs)
-        assert (
-            len(log_files) > 0
-        ), f"No log files found in {temp_dir_logs}. Expected log files to be generated during Triton compilation."
- print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}") - - # Check that log files contain specific counts of 'launch' and 'compilation' event types - def check_event_type_counts_in_logs(log_dir): - """Check that log files contain exactly 1 'compilation' and 1 'launch' event types""" - event_type_counts = {"compilation": 0, "launch": 0} + return torch.randn(size, device=self.cuda_device, dtype=torch.float32) + + def _verify_log_directory(self, log_dir): + """Verify that log directory exists and contains files""" + assert os.path.exists(log_dir), f"Log directory {log_dir} does not exist." + + log_files = os.listdir(log_dir) + assert len(log_files) > 0, ( + f"No log files found in {log_dir}. " + "Expected log files to be generated during Triton compilation." + ) + print(f"Found {len(log_files)} log files in {log_dir}: {log_files}") + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_whole_workflow(self): + """Test unified_parse functionality""" + # Set up test environment + temp_dir, temp_dir_logs, temp_dir_parsed = self._setup_test_directories() + print(f"Temporary directory: {temp_dir}") + + # Initialize logging + tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True) + + # Generate test data and run kernels + test_kernel = self._create_test_kernel() + x = self._generate_test_data() + + # Run kernel twice to generate compilation and launch events + self._run_kernel(test_kernel, x) + self._run_kernel(test_kernel, x) + torch.cuda.synchronize() + + # Verify log directory + self._verify_log_directory(temp_dir_logs) + + def parse_log_line(line: str, line_num: int) -> dict | None: + """Parse a single log line and extract event data""" + try: + return json.loads(line.strip()) + except json.JSONDecodeError as e: + print(f" Line {line_num}: JSON decode error - {e}") + return None + + def process_event_data(event_data: dict, line_num: int, event_counts: dict) -> None: + """Process event data and update counts""" + try: + event_type = event_data.get("event_type") + if event_type is None: + return + + if event_type in event_counts: + event_counts[event_type] += 1 + print(f" Line {line_num}: event_type = '{event_type}' (count: {event_counts[event_type]})") + else: + print(f" Line {line_num}: event_type = '{event_type}' (not tracked)") + except (KeyError, TypeError) as e: + print(f" Line {line_num}: Data structure error - {e}") + + def count_events_in_file(file_path: str, event_counts: dict) -> None: + """Count events in a single log file""" + print(f"Checking event types in: {os.path.basename(file_path)}") + + with open(file_path, "r") as f: + for line_num, line in enumerate(f, 1): + event_data = parse_log_line(line, line_num) + if event_data: + process_event_data(event_data, line_num, event_counts) + + def check_event_type_counts_in_logs(log_dir: str) -> dict: + """Count 'launch' and 'compilation' events in all log files""" + event_counts = {"compilation": 0, "launch": 0} + for log_file in os.listdir(log_dir): if log_file.endswith(".ndjson"): log_file_path = os.path.join(log_dir, log_file) - print(f"Checking event types in: {log_file}") - - with open(log_file_path, "r") as f: - for line_num, line in enumerate(f, 1): - try: - # Parse each line as JSON - event_data = json.loads(line.strip()) - - # Extract event_type if present - if "event_type" in event_data: - event_type = event_data["event_type"] - if event_type in event_type_counts: - event_type_counts[event_type] += 1 - print( - f" Line {line_num}: event_type = '{event_type}' 
(count: {event_type_counts[event_type]})" - ) - else: - print( - f" Line {line_num}: event_type = '{event_type}' (not tracked)" - ) - - except json.JSONDecodeError as e: - print(f" Line {line_num}: JSON decode error - {e}") - continue - except Exception as e: - print(f" Line {line_num}: Unexpected error - {e}") - continue - - print(f"Event type counts: {event_type_counts}") - return event_type_counts - - # Check event type counts in log files - event_counts = check_event_type_counts_in_logs(temp_dir_logs) - - # Assert specific counts for 'launch' and 'compilation' event types - assert ( - event_counts["compilation"] == 1 - ), f"Expected 1 'compilation' event, found {event_counts['compilation']}" - assert ( - event_counts["launch"] == 2 - ), f"Expected 2 'launch' events, found {event_counts['launch']}" - print("✓ Verified correct event type counts: 1 compilation, 2 launch") + count_events_in_file(log_file_path, event_counts) + + print(f"Event type counts: {event_counts}") + return event_counts + # Verify event counts + event_counts = check_event_type_counts_in_logs(temp_dir_logs) + self._verify_event_counts(event_counts) + + # Test parsing functionality tritonparse.utils.unified_parse( source=temp_dir_logs, out=temp_dir_parsed, overwrite=True ) + + # Verify parsing output + self._verify_parsing_output(temp_dir_parsed) + + # Clean up + shutil.rmtree(temp_dir) + + def _verify_event_counts(self, event_counts): + """Verify that event counts match expected values""" + assert event_counts["compilation"] == 1, ( + f"Expected 1 'compilation' event, found {event_counts['compilation']}" + ) + assert event_counts["launch"] == 2, ( + f"Expected 2 'launch' events, found {event_counts['launch']}" + ) + print("✓ Verified correct event type counts: 1 compilation, 2 launch") - # Clean up temporary directory - try: - # Check that parsed output directory has files - parsed_files = os.listdir(temp_dir_parsed) - assert len(parsed_files) > 0, "No files found in parsed output directory" - finally: - shutil.rmtree(temp_dir) + def _verify_parsing_output(self, parsed_dir): + """Verify that parsing output directory contains files""" + parsed_files = os.listdir(parsed_dir) + assert len(parsed_files) > 0, "No files found in parsed output directory" if __name__ == "__main__": From 4b2e421faf6b36238911794097c4164bc8f854d3 Mon Sep 17 00:00:00 2001 From: FindHao Date: Mon, 14 Jul 2025 19:52:22 -0400 Subject: [PATCH 3/4] fix lint --- tests/test_tritonparse.py | 55 ++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/tests/test_tritonparse.py b/tests/test_tritonparse.py index c6d0b94..fc717c7 100644 --- a/tests/test_tritonparse.py +++ b/tests/test_tritonparse.py @@ -167,6 +167,7 @@ def compile_listener( def _create_test_kernel(self): """Create a simple test kernel for compilation testing""" + @triton.jit def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) @@ -177,7 +178,7 @@ def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): x = tl.load(x_ptr + offsets, mask=mask) y = x + 1.0 # Simple operation: add 1 tl.store(y_ptr + offsets, y, mask=mask) - + return test_kernel def _run_kernel(self, kernel, x): @@ -194,10 +195,10 @@ def _setup_test_directories(self): temp_dir = tempfile.mkdtemp() temp_dir_logs = os.path.join(temp_dir, "logs") temp_dir_parsed = os.path.join(temp_dir, "parsed_output") - + os.makedirs(temp_dir_logs, exist_ok=True) os.makedirs(temp_dir_parsed, exist_ok=True) - + return temp_dir, temp_dir_logs, 
 
     def _generate_test_data(self):
@@ -209,7 +210,7 @@ def _generate_test_data(self):
 
     def _verify_log_directory(self, log_dir):
         """Verify that log directory exists and contains files"""
         assert os.path.exists(log_dir), f"Log directory {log_dir} does not exist."
-
+
         log_files = os.listdir(log_dir)
@@ -223,19 +224,19 @@ def test_whole_workflow(self):
         # Set up test environment
         temp_dir, temp_dir_logs, temp_dir_parsed = self._setup_test_directories()
         print(f"Temporary directory: {temp_dir}")
-
+
         # Initialize logging
         tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True)
-
+
         # Generate test data and run kernels
         test_kernel = self._create_test_kernel()
         x = self._generate_test_data()
-
+
         # Run kernel twice to generate compilation and launch events
         self._run_kernel(test_kernel, x)
         self._run_kernel(test_kernel, x)
         torch.cuda.synchronize()
-
+
         # Verify log directory
         self._verify_log_directory(temp_dir_logs)
 
@@ -247,25 +248,31 @@ def parse_log_line(line: str, line_num: int) -> dict | None:
             print(f"  Line {line_num}: JSON decode error - {e}")
             return None
 
-        def process_event_data(event_data: dict, line_num: int, event_counts: dict) -> None:
+        def process_event_data(
+            event_data: dict, line_num: int, event_counts: dict
+        ) -> None:
             """Process event data and update counts"""
             try:
                 event_type = event_data.get("event_type")
                 if event_type is None:
                     return
-
+
                 if event_type in event_counts:
                     event_counts[event_type] += 1
-                    print(f"  Line {line_num}: event_type = '{event_type}' (count: {event_counts[event_type]})")
+                    print(
+                        f"  Line {line_num}: event_type = '{event_type}' (count: {event_counts[event_type]})"
+                    )
                 else:
-                    print(f"  Line {line_num}: event_type = '{event_type}' (not tracked)")
+                    print(
+                        f"  Line {line_num}: event_type = '{event_type}' (not tracked)"
+                    )
             except (KeyError, TypeError) as e:
                 print(f"  Line {line_num}: Data structure error - {e}")
 
         def count_events_in_file(file_path: str, event_counts: dict) -> None:
             """Count events in a single log file"""
             print(f"Checking event types in: {os.path.basename(file_path)}")
-
+
             with open(file_path, "r") as f:
                 for line_num, line in enumerate(f, 1):
                     event_data = parse_log_line(line, line_num)
@@ -275,38 +282,38 @@ def count_events_in_file(file_path: str, event_counts: dict) -> None:
         def check_event_type_counts_in_logs(log_dir: str) -> dict:
             """Count 'launch' and 'compilation' events in all log files"""
             event_counts = {"compilation": 0, "launch": 0}
-
+
             for log_file in os.listdir(log_dir):
                 if log_file.endswith(".ndjson"):
                     log_file_path = os.path.join(log_dir, log_file)
                     count_events_in_file(log_file_path, event_counts)
-
+
             print(f"Event type counts: {event_counts}")
             return event_counts
 
         # Verify event counts
         event_counts = check_event_type_counts_in_logs(temp_dir_logs)
         self._verify_event_counts(event_counts)
-
+
         # Test parsing functionality
         tritonparse.utils.unified_parse(
             source=temp_dir_logs, out=temp_dir_parsed, overwrite=True
         )
-
+
         # Verify parsing output
         self._verify_parsing_output(temp_dir_parsed)
-
+
         # Clean up
         shutil.rmtree(temp_dir)
 
     def _verify_event_counts(self, event_counts):
         """Verify that event counts match expected values"""
-        assert event_counts["compilation"] == 1, (
-            f"Expected 1 'compilation' event, found {event_counts['compilation']}"
-        )
-        assert event_counts["launch"] == 2, (
-            f"Expected 2 'launch' events, found {event_counts['launch']}"
-        )
+        assert (
+            event_counts["compilation"] == 1
+        ), f"Expected 1 'compilation' event, found {event_counts['compilation']}"
{event_counts['compilation']}" + assert ( + event_counts["launch"] == 2 + ), f"Expected 2 'launch' events, found {event_counts['launch']}" print("✓ Verified correct event type counts: 1 compilation, 2 launch") def _verify_parsing_output(self, parsed_dir): From f3af4f3cbd6238c549d0c8a07d6e97bce42d133c Mon Sep 17 00:00:00 2001 From: FindHao Date: Tue, 15 Jul 2025 16:46:50 -0700 Subject: [PATCH 4/4] Refactor test_whole_workflow in test_tritonparse.py for improved clarity and functionality Summary: - Consolidated kernel creation and execution into the `test_whole_workflow` method, enhancing readability. - Removed redundant helper methods and integrated their logic directly into the test, streamlining the workflow. - Added assertions to verify the existence of log files and event counts, ensuring comprehensive testing of the unified_parse functionality. - Simplified setup and data generation processes within the test, improving maintainability. These changes aim to enhance the organization and robustness of the test suite, facilitating easier future modifications and debugging. --- tests/test_tritonparse.py | 98 ++++++++++++++------------------------- 1 file changed, 36 insertions(+), 62 deletions(-) diff --git a/tests/test_tritonparse.py b/tests/test_tritonparse.py index fc717c7..d01a060 100644 --- a/tests/test_tritonparse.py +++ b/tests/test_tritonparse.py @@ -165,9 +165,11 @@ def compile_listener( assert "python_source" in trace_data assert "file_path" in trace_data["python_source"] - def _create_test_kernel(self): - """Create a simple test kernel for compilation testing""" + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_whole_workflow(self): + """Test unified_parse functionality""" + # Define a simple kernel directly in the test function @triton.jit def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) @@ -179,66 +181,46 @@ def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): y = x + 1.0 # Simple operation: add 1 tl.store(y_ptr + offsets, y, mask=mask) - return test_kernel - - def _run_kernel(self, kernel, x): - """Run a kernel with given input tensor""" - n_elements = x.numel() - y = torch.empty_like(x) - BLOCK_SIZE = 256 - grid = (triton.cdiv(n_elements, BLOCK_SIZE),) - kernel[grid](x, y, n_elements, BLOCK_SIZE) - return y + # Simple function to run the kernel + def run_test_kernel(x): + n_elements = x.numel() + y = torch.empty_like(x) + BLOCK_SIZE = 256 + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + test_kernel[grid](x, y, n_elements, BLOCK_SIZE) + return y - def _setup_test_directories(self): - """Set up temporary directories for testing""" + # Set up test environment temp_dir = tempfile.mkdtemp() temp_dir_logs = os.path.join(temp_dir, "logs") temp_dir_parsed = os.path.join(temp_dir, "parsed_output") - os.makedirs(temp_dir_logs, exist_ok=True) os.makedirs(temp_dir_parsed, exist_ok=True) - - return temp_dir, temp_dir_logs, temp_dir_parsed - - def _generate_test_data(self): - """Generate test tensor data""" - torch.manual_seed(0) - size = (512, 512) # Smaller size for faster testing - return torch.randn(size, device=self.cuda_device, dtype=torch.float32) - - def _verify_log_directory(self, log_dir): - """Verify that log directory exists and contains files""" - assert os.path.exists(log_dir), f"Log directory {log_dir} does not exist." - - log_files = os.listdir(log_dir) - assert len(log_files) > 0, ( - f"No log files found in {log_dir}. 
" - "Expected log files to be generated during Triton compilation." - ) - print(f"Found {len(log_files)} log files in {log_dir}: {log_files}") - - @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") - def test_whole_workflow(self): - """Test unified_parse functionality""" - # Set up test environment - temp_dir, temp_dir_logs, temp_dir_parsed = self._setup_test_directories() print(f"Temporary directory: {temp_dir}") # Initialize logging tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True) # Generate test data and run kernels - test_kernel = self._create_test_kernel() - x = self._generate_test_data() + torch.manual_seed(0) + size = (512, 512) # Smaller size for faster testing + x = torch.randn(size, device=self.cuda_device, dtype=torch.float32) # Run kernel twice to generate compilation and launch events - self._run_kernel(test_kernel, x) - self._run_kernel(test_kernel, x) + run_test_kernel(x) + run_test_kernel(x) torch.cuda.synchronize() # Verify log directory - self._verify_log_directory(temp_dir_logs) + assert os.path.exists( + temp_dir_logs + ), f"Log directory {temp_dir_logs} does not exist." + log_files = os.listdir(temp_dir_logs) + assert len(log_files) > 0, ( + f"No log files found in {temp_dir_logs}. " + "Expected log files to be generated during Triton compilation." + ) + print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}") def parse_log_line(line: str, line_num: int) -> dict | None: """Parse a single log line and extract event data""" @@ -293,7 +275,13 @@ def check_event_type_counts_in_logs(log_dir: str) -> dict: # Verify event counts event_counts = check_event_type_counts_in_logs(temp_dir_logs) - self._verify_event_counts(event_counts) + assert ( + event_counts["compilation"] == 1 + ), f"Expected 1 'compilation' event, found {event_counts['compilation']}" + assert ( + event_counts["launch"] == 2 + ), f"Expected 2 'launch' events, found {event_counts['launch']}" + print("✓ Verified correct event type counts: 1 compilation, 2 launch") # Test parsing functionality tritonparse.utils.unified_parse( @@ -301,26 +289,12 @@ def check_event_type_counts_in_logs(log_dir: str) -> dict: ) # Verify parsing output - self._verify_parsing_output(temp_dir_parsed) + parsed_files = os.listdir(temp_dir_parsed) + assert len(parsed_files) > 0, "No files found in parsed output directory" # Clean up shutil.rmtree(temp_dir) - def _verify_event_counts(self, event_counts): - """Verify that event counts match expected values""" - assert ( - event_counts["compilation"] == 1 - ), f"Expected 1 'compilation' event, found {event_counts['compilation']}" - assert ( - event_counts["launch"] == 2 - ), f"Expected 2 'launch' events, found {event_counts['launch']}" - print("✓ Verified correct event type counts: 1 compilation, 2 launch") - - def _verify_parsing_output(self, parsed_dir): - """Verify that parsing output directory contains files""" - parsed_files = os.listdir(parsed_dir) - assert len(parsed_files) > 0, "No files found in parsed output directory" - if __name__ == "__main__": unittest.main()