Skip to content

Enhance structured logging initialization and test coverage #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tests/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
```
"""

import os

import torch
import triton
import triton.language as tl
Expand All @@ -15,7 +17,9 @@
import tritonparse.utils

log_path = "./logs"
tritonparse.structured_logging.init(log_path)
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)

os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"


@triton.jit
Expand Down
109 changes: 86 additions & 23 deletions tests/test_tritonparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
```
"""

import json
import os
import shutil
import tempfile
Expand Down Expand Up @@ -170,12 +171,7 @@ def test_whole_workflow(self):

# Define a simple kernel directly in the test function
@triton.jit
def test_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
def test_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
Expand All @@ -189,48 +185,115 @@ def test_kernel(
def run_test_kernel(x):
n_elements = x.numel()
y = torch.empty_like(x)
BLOCK_SIZE = 256 # Smaller block size for simplicity
BLOCK_SIZE = 256
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
test_kernel[grid](x, y, n_elements, BLOCK_SIZE)
return y

# Set up test environment
temp_dir = tempfile.mkdtemp()
print(f"Temporary directory: {temp_dir}")
temp_dir_logs = os.path.join(temp_dir, "logs")
os.makedirs(temp_dir_logs, exist_ok=True)
temp_dir_parsed = os.path.join(temp_dir, "parsed_output")
os.makedirs(temp_dir_logs, exist_ok=True)
os.makedirs(temp_dir_parsed, exist_ok=True)
print(f"Temporary directory: {temp_dir}")

tritonparse.structured_logging.init(temp_dir_logs)
# Initialize logging
tritonparse.structured_logging.init(temp_dir_logs, enable_trace_launch=True)

# Generate some triton compilation activity to create log files
# Generate test data and run kernels
torch.manual_seed(0)
size = (512, 512) # Smaller size for faster testing
x = torch.randn(size, device=self.cuda_device, dtype=torch.float32)
run_test_kernel(x) # Run the simple kernel

# Run kernel twice to generate compilation and launch events
run_test_kernel(x)
run_test_kernel(x)
torch.cuda.synchronize()

# Check that temp_dir_logs folder has content
# Verify log directory
assert os.path.exists(
temp_dir_logs
), f"Log directory {temp_dir_logs} does not exist."
log_files = os.listdir(temp_dir_logs)
assert (
len(log_files) > 0
), f"No log files found in {temp_dir_logs}. Expected log files to be generated during Triton compilation."
assert len(log_files) > 0, (
f"No log files found in {temp_dir_logs}. "
"Expected log files to be generated during Triton compilation."
)
print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}")

def parse_log_line(line: str, line_num: int) -> dict | None:
"""Parse a single log line and extract event data"""
try:
return json.loads(line.strip())
except json.JSONDecodeError as e:
print(f" Line {line_num}: JSON decode error - {e}")
return None

def process_event_data(
    event_data: dict, line_num: int, event_counts: dict
) -> None:
    """Tally the event_type of one parsed log event into event_counts.

    Events whose type is missing are ignored; types not present as keys
    of event_counts are reported but not counted.
    """
    try:
        kind = event_data.get("event_type")
        if kind is None:
            # Line carried no event_type field — nothing to tally.
            return
        if kind not in event_counts:
            print(
                f" Line {line_num}: event_type = '{kind}' (not tracked)"
            )
            return
        event_counts[kind] += 1
        print(
            f" Line {line_num}: event_type = '{kind}' (count: {event_counts[kind]})"
        )
    except (KeyError, TypeError) as exc:
        print(f" Line {line_num}: Data structure error - {exc}")

def count_events_in_file(file_path: str, event_counts: dict) -> None:
    """Scan one NDJSON log file and tally its events into event_counts."""
    print(f"Checking event types in: {os.path.basename(file_path)}")

    with open(file_path, "r") as fh:
        for idx, raw_line in enumerate(fh, start=1):
            parsed = parse_log_line(raw_line, idx)
            if parsed:
                process_event_data(parsed, idx, event_counts)

def check_event_type_counts_in_logs(log_dir: str) -> dict:
    """Count 'compilation' and 'launch' events across all .ndjson files in log_dir.

    Returns the mapping of tracked event type -> occurrence count.
    """
    event_counts = {"compilation": 0, "launch": 0}

    ndjson_files = [
        name for name in os.listdir(log_dir) if name.endswith(".ndjson")
    ]
    for name in ndjson_files:
        count_events_in_file(os.path.join(log_dir, name), event_counts)

    print(f"Event type counts: {event_counts}")
    return event_counts

# Verify event counts
event_counts = check_event_type_counts_in_logs(temp_dir_logs)
assert (
event_counts["compilation"] == 1
), f"Expected 1 'compilation' event, found {event_counts['compilation']}"
assert (
event_counts["launch"] == 2
), f"Expected 2 'launch' events, found {event_counts['launch']}"
print("✓ Verified correct event type counts: 1 compilation, 2 launch")

# Test parsing functionality
tritonparse.utils.unified_parse(
source=temp_dir_logs, out=temp_dir_parsed, overwrite=True
)

# Clean up temporary directory
try:
# Check that parsed output directory has files
parsed_files = os.listdir(temp_dir_parsed)
assert len(parsed_files) > 0, "No files found in parsed output directory"
finally:
shutil.rmtree(temp_dir)
# Verify parsing output
parsed_files = os.listdir(temp_dir_parsed)
assert len(parsed_files) > 0, "No files found in parsed output directory"

# Clean up
shutil.rmtree(temp_dir)


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion tritonparse/structured_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,13 +981,16 @@ def init_basic(trace_folder: Optional[str] = None):
maybe_enable_trace_launch()


def init(trace_folder: Optional[str] = None):
def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
"""
This function is a wrapper around init_basic() that also setup the compilation listener.

Args:
trace_folder (Optional[str]): The folder to store the trace files.
"""
global TRITON_TRACE_LAUNCH
if enable_trace_launch:
TRITON_TRACE_LAUNCH = True
import triton

init_basic(trace_folder)
Expand Down