02. Usage Guide

Yueming Hao edited this page Jul 10, 2025 · 4 revisions

This guide walks you through the complete workflow of using TritonParse to analyze Triton kernel compilation processes.

📋 Overview

The TritonParse workflow consists of three main steps:

  1. Generate Traces - Capture Triton compilation events
  2. Parse Traces - Process raw logs into structured format
  3. Analyze Results - Visualize and explore using the web interface

🚀 Step 1: Generate Triton Trace Files

Basic Setup

First, integrate TritonParse into your Triton/PyTorch code:

import torch
import triton
import triton.language as tl

# === TritonParse initialization ===
import tritonparse.structured_logging

# Initialize structured logging to capture Triton compilation events
log_path = "./logs/"
tritonparse.structured_logging.init(log_path)
# === End TritonParse initialization ===

# Your original Triton/PyTorch code below...

Example: Complete Triton Kernel

Here's a complete example showing how to instrument a Triton kernel:

import torch
import triton
import triton.language as tl
import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path)

@triton.jit
def add_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def tensor_add(a, b):
    n_elements = a.numel()
    c = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE)
    return c

# Example usage
if __name__ == "__main__":
    # Create test tensors (Triton kernels need a GPU, so there is no CPU fallback)
    device = "cuda"
    a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    b = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    
    # Execute kernel (this will be traced)
    c = tensor_add(a, b)
    
    # Parse the generated logs
    tritonparse.utils.unified_parse(
        source=log_path, 
        out="./parsed_output", 
        overwrite=True
    )

PyTorch 2.0 Integration

For PyTorch 2.0 compiled functions:

import torch
import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path)

def simple_add(a, b):
    return a + b

# Test with torch.compile
compiled_add = torch.compile(simple_add)

# Create test data
device = "cuda"
a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
b = torch.randn(1024, 1024, device=device, dtype=torch.float32)

# Execute compiled function (this will be traced)
result = compiled_add(a, b)

# Parse the generated logs
tritonparse.utils.unified_parse(
    source=log_path, 
    out="./parsed_output", 
    overwrite=True
)

Important Environment Variables

Set these before running your code:

# Disable FX graph cache to ensure compilation happens every time
export TORCHINDUCTOR_FX_GRAPH_CACHE=0

# Enable debug logging (optional)
export TRITONPARSE_DEBUG=1

# Enable NDJSON output (default)
export TRITONPARSE_NDJSON=1

# Enable gzip compression for trace files (optional)
export TRITON_TRACE_GZIP=1
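
The same variables can also be set from inside a Python script. The sketch below assumes that setting them before `torch` and `tritonparse` are imported is enough for them to take effect; the source only shows the shell form, so treat this as an illustration. `configure_tritonparse_env` is a hypothetical helper name, not part of TritonParse.

```python
import os

def configure_tritonparse_env(debug: bool = False, gzip_traces: bool = False) -> dict:
    """Set the environment variables listed above and return what was set.

    Hypothetical helper for illustration; not part of the TritonParse API.
    """
    # Always disable the FX graph cache so compilation happens on every run
    settings = {"TORCHINDUCTOR_FX_GRAPH_CACHE": "0"}
    if debug:
        settings["TRITONPARSE_DEBUG"] = "1"
    if gzip_traces:
        settings["TRITON_TRACE_GZIP"] = "1"
    os.environ.update(settings)
    return settings

# Call this before `import torch` and `import tritonparse.structured_logging`
applied = configure_tritonparse_env(debug=True)
print(applied)
```

Exporting the variables in the shell, as shown above, remains the simpler option when you control how the script is launched.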

Running the Code

# Run your instrumented code
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python your_script.py

Expected Output:

Triton kernel executed successfully
Torch compiled function executed successfully
tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output

================================================================================
📁 TRITONPARSE PARSING RESULTS
================================================================================
📂 Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
📊 Total files generated: 2

📄 Generated files:
--------------------------------------------------
   1. 📝 dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
   2. 📝 log_file_list.json (181B)
================================================================================
✅ Parsing completed successfully!
================================================================================

🔧 Step 2: Parse Trace Files

Using unified_parse

The unified_parse function processes raw logs into a structured format:

import tritonparse.utils

# Parse logs from directory
tritonparse.utils.unified_parse(
    source="./logs/",           # Input directory with raw logs
    out="./parsed_output",      # Output directory for processed files
    overwrite=True              # Overwrite existing output directory
)

Advanced Parsing Options

# Parse with additional options
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    overwrite=True,
    rank=0,                     # Analyze a specific rank (for multi-GPU)
    all_ranks=False,            # Set to True to analyze all ranks instead
    verbose=True                # Enable verbose logging
)

Understanding the Output

After parsing, you'll have:

parsed_output/
├── f0_fc0_a0_cai-.ndjson.gz          # Compressed trace
├── dedicated_log_triton_trace_findhao__mapped.ndjson.gz   # Another trace
├── ...
└── log_file_list.json        # Index of all generated files (optional)

Each ndjson.gz file contains:

  • Kernel metadata (grid size, block size, etc.)
  • All IR stages (TTIR, TTGIR, LLIR, PTX, AMDGCN)
  • Source mappings between IR stages
  • Compilation stack traces
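
To inspect one of these files outside the web interface, it can be read with the Python standard library. This is a minimal sketch that assumes each line of the decompressed file is one JSON object (as the NDJSON name implies). It makes no assumptions about TritonParse's field names and only counts whichever top-level keys appear. `summarize_ndjson_gz` is a hypothetical helper name.

```python
import gzip
import json
from collections import Counter
from pathlib import Path

def summarize_ndjson_gz(path):
    """Return (record_count, Counter of top-level keys) for a .ndjson.gz file."""
    key_counts = Counter()
    n_records = 0
    # Open in text mode so each iteration yields one decoded line
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            record = json.loads(line)
            n_records += 1
            if isinstance(record, dict):
                key_counts.update(record.keys())
    return n_records, key_counts

if __name__ == "__main__":
    # Summarize every parsed trace in the output directory used in this guide
    for gz_file in sorted(Path("./parsed_output").glob("*.ndjson.gz")):
        n, keys = summarize_ndjson_gz(gz_file)
        print(gz_file.name, n, dict(keys))
```

Printing the key counts first is a cheap way to learn the record layout before writing anything that depends on specific fields.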

Command Line Usage

You can also use the command line interface:

# Basic usage
python run.py ./logs/ -o ./parsed_output

# With options
python run.py ./logs/ -o ./parsed_output --overwrite --verbose

# Parse specific rank
python run.py ./logs/ -o ./parsed_output --rank 0

# Parse all ranks
python run.py ./logs/ -o ./parsed_output --all-ranks

🌐 Step 3: Analyze with Web Interface

Option A: Online Interface (Recommended)

  1. Visit the live tool: https://pytorch-labs.github.io/tritonparse/

  2. Load your trace files:

    • Click "Browse Files" or drag-and-drop
    • Select .gz files from your parsed_output directory
    • Or select .ndjson files from your logs directory
  3. Explore the visualization:

    • Overview Tab: Kernel metadata, call stack, IR links
    • Comparison Tab: Side-by-side IR comparison with line mapping

Option B: Local Development Interface

For contributors or custom deployments:

cd website
npm install
npm run dev

Access at http://localhost:5173

Supported File Formats

| Format | Description | Source Mapping | Recommended |
| --- | --- | --- | --- |
| .gz | Compressed parsed traces | ✅ Yes | ✅ Yes |
| .ndjson | Raw trace logs | ❌ No | ⚠️ Basic use only |

Note: .ndjson files don't contain source code mappings between IR stages. Always use .gz files for full functionality.

📊 Understanding the Results

Kernel Overview

The overview page shows:

  • Kernel Information: Name, hash, grid/block sizes
  • Compilation Metadata: Device, compile time, memory usage
  • Call Stack: Python source code that triggered compilation
  • IR Navigation: Links to different IR representations

Code Comparison

The comparison view offers:

  • Side-by-side IR viewing: Compare different compilation stages
  • Synchronized highlighting: Click a line to see corresponding lines in other IRs
  • Source mapping: Trace transformations across compilation pipeline

IR Stages Explained

| Stage | Description | When Generated |
| --- | --- | --- |
| TTIR | Triton IR - Language-level operations | After the Triton frontend parses the kernel |
| TTGIR | Triton GPU IR - High-level GPU operations | After TTIR is lowered to the GPU dialect |
| LLIR | LLVM IR - Low-level operations | After LLVM conversion |
| PTX | NVIDIA PTX Assembly | For NVIDIA GPUs |
| AMDGCN | AMD GPU Assembly | For AMD GPUs |

🎯 Common Use Cases

Understanding Compilation Pipeline

# Trace a simple kernel to understand compilation stages
@triton.jit
def simple_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    
    x = tl.load(x_ptr + offsets, mask=mask)
    y = x * 2.0  # Simple operation
    tl.store(y_ptr + offsets, y, mask=mask)

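# Launch this kernel after calling tritonparse.structured_logging.init(),
# then run unified_parse and compare the IR stages in the web interface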

πŸ” Advanced Features

Filtering Kernels

Set a kernel allowlist to trace only specific kernels:

# Only trace kernels matching these patterns
export TRITONPARSE_KERNEL_ALLOWLIST="my_kernel*,important_*"
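
To preview which kernel names a given allowlist would select, the patterns can be tested offline. The sketch below assumes the allowlist is a comma-separated list of shell-style glob patterns, as the `*` in the example suggests. TritonParse's actual matching rules are not confirmed here, and `parse_allowlist` and `is_allowed` are hypothetical helpers.

```python
from fnmatch import fnmatchcase

def parse_allowlist(value: str) -> list[str]:
    """Split a comma-separated allowlist string into individual patterns."""
    return [p.strip() for p in value.split(",") if p.strip()]

def is_allowed(kernel_name: str, patterns: list[str]) -> bool:
    """Return True if the kernel name matches any glob pattern (case-sensitive)."""
    return any(fnmatchcase(kernel_name, p) for p in patterns)

patterns = parse_allowlist("my_kernel*,important_*")
for name in ["my_kernel_v2", "important_add", "add_kernel"]:
    print(name, is_allowed(name, patterns))
```

Under this assumption, `add_kernel` from the earlier example would not be traced with the allowlist shown above, because it matches neither pattern.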

Multi-GPU Analysis

For multi-GPU setups:

# Parse all ranks
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    all_ranks=True  # Analyze all GPU ranks
)

# Or parse specific rank
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    rank=1  # Analyze GPU rank 1
)

Launch Tracing

Enable launch metadata tracing:

# Enable launch tracing (experimental)
export TRITON_TRACE_LAUNCH=1

TODO: add more details about launch tracing.

πŸ› Troubleshooting

Common Issues

1. No Kernels Found

Error: No kernels found in the processed data

Solutions:

  • Ensure TORCHINDUCTOR_FX_GRAPH_CACHE=0 is set
  • Check that your kernel actually executes
  • Verify Triton is properly installed

2. Empty Log Files

Warning: Empty log directory

Solutions:

  • Ensure tritonparse.structured_logging.init() is called before kernel execution
  • Check that your code path actually executes Triton kernels
  • Verify log directory permissions
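
These checks can be partly automated by inspecting the log directory after a run. The following is a minimal sketch using only the standard library. `check_log_dir` is a hypothetical helper that reports file sizes and does not validate trace contents.

```python
from pathlib import Path

def check_log_dir(log_dir):
    """Return (non_empty_files, empty_files) found directly inside log_dir."""
    root = Path(log_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"log directory does not exist: {root}")
    non_empty, empty = [], []
    for path in sorted(root.iterdir()):
        if path.is_file():
            # Sort each file into a bucket based on its size on disk
            (non_empty if path.stat().st_size > 0 else empty).append(path.name)
    return non_empty, empty

if __name__ == "__main__":
    log_path = Path("./logs/")  # same directory passed to structured_logging.init
    if log_path.is_dir():
        ok, empty = check_log_dir(log_path)
        print("non-empty:", ok)
        print("empty:", empty)
    else:
        print("no log directory found at", log_path)
```

If both lists come back empty, no trace files were written, which suggests that logging was initialized too late or that no Triton kernel was compiled.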

3. Source Mapping Warnings

WARNING:SourceMapping:No frame_id or frame_compile_id found in the payload.

Solutions:

  • This is often normal for PyTorch 2.0 compiled functions
  • Check that parsing completed successfully

4. Web Interface Issues

Error: Failed to load trace file

Solutions:

  • Ensure you're using .gz files from parsed_output
  • Check file size limits (browser dependent)
  • Try with a smaller trace file first

Debug Tips

  1. Enable verbose logging:

    export TRITONPARSE_DEBUG=1
  2. Check log file contents:

    ls -la ./logs/
    head -n 5 ./logs/*.ndjson
  3. Verify parsing output:

    ls -la ./parsed_output/
    zcat ./parsed_output/*.gz | head -n 10

🔗 Next Steps

After successfully generating and analyzing traces:

  1. Learn the Web Interface: Read the Web Interface Guide
  2. Explore Advanced Features: Check Advanced Examples
  3. Understand File Formats: See File Formats documentation
  4. Get Help: Visit our FAQ or GitHub Discussions

📚 Related Documentation