diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py
index 823e564..9dfd928 100644
--- a/tritonparse/structured_logging.py
+++ b/tritonparse/structured_logging.py
@@ -1,12 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import atexit
+import gzip
+import hashlib
 import importlib
 import inspect
 import json
 import logging
 import math
 import os
+import tempfile
 from collections import defaultdict
 from dataclasses import asdict, is_dataclass
 from datetime import date, datetime
@@ -32,7 +35,17 @@
 DEFAULT_TRACE_FILE_PREFIX = (
     f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_"
 )
+# Enable launch tracing. WARNING: this overrides launch_metadata for every Triton kernel.
+TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
+# Enable tensor blob storage
+TRITONPARSE_SAVE_TENSOR_BLOBS = os.getenv("TRITONPARSE_SAVE_TENSOR_BLOBS", "0") in ["1", "true", "True"]
+# Tensor size limit in bytes (default 10 GB)
+TRITONPARSE_TENSOR_SIZE_LIMIT = int(os.getenv("TRITONPARSE_TENSOR_SIZE_LIMIT", str(10 * 1024 * 1024 * 1024)))
+
 TRITON_TRACE_HANDLER = None
+# Global tensor blob manager instance
+TENSOR_BLOB_MANAGER = None
+
 if importlib.util.find_spec("torch") is not None:
     TORCH_INSTALLED = True
     import torch
@@ -41,6 +54,131 @@
     TORCH_INSTALLED = False
 
 
+class TensorBlobManager:
+    """
+    Manager for storing tensor data as content-addressed blobs.
+
+    Uses BLAKE2b hashing for content addressing and stores blobs in a two-level
+    directory structure to avoid filesystem limitations with large numbers of files.
+    """
+
+    def __init__(self, root_dir: Optional[str] = None):
+        self.root_dir = None
+        self.hash_to_path_cache = {}  # In-memory cache for hash -> path mapping
+        if root_dir:
+            self.set_root_dir(root_dir)
+
+    def set_root_dir(self, root_dir: str):
+        """Set the root directory for blob storage."""
+        self.root_dir = Path(root_dir) / "saved_tensors"
+        self.root_dir.mkdir(parents=True, exist_ok=True)
+        log.debug(f"TensorBlobManager: using root directory {self.root_dir}")
+
+    def _compute_hash(self, data: bytes) -> str:
+        """Compute the BLAKE2b hash of the data."""
+        return hashlib.blake2b(data).hexdigest()
+
+    def _get_blob_path(self, hash_hex: str) -> Path:
+        """Get the file path for a given hash using a two-level directory structure."""
+        if not self.root_dir:
+            raise ValueError("Root directory not set")
+
+        # Two-level directory: first 2 chars / full_hash.bin
+        subdir = hash_hex[:2]
+        filename = f"{hash_hex}.bin"
+        return (self.root_dir / subdir / filename).resolve()
+
+    def _get_tensor_size_bytes(self, tensor) -> int:
+        """Get the tensor size in bytes before serialization."""
+        if hasattr(tensor, 'numel') and hasattr(tensor, 'element_size'):
+            return tensor.numel() * tensor.element_size()
+        return 0
+
+    def save_tensor_blob(self, tensor) -> Dict[str, Any]:
+        """
+        Save a tensor as a blob and return its metadata.
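+
+        Illustrative usage (a sketch, assuming PyTorch is installed and a
+        writable log directory; the path below is hypothetical):
+
+            manager = TensorBlobManager("/tmp/tritonparse_logs")
+            result = manager.save_tensor_blob(torch.ones(2, 3))
+            if result.get("tensor_hash") is not None:
+                print(result["blob_path"], result["blob_size"])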
+
+        Args:
+            tensor: PyTorch tensor to save
+
+        Returns:
+            Dictionary with blob metadata or error information:
+            - Success: {'tensor_hash': str, 'blob_path': str, 'blob_size': int, 'serialization_method': str}
+            - Error: {'error': str, 'tensor_hash': None}
+        """
+        if not self.root_dir:
+            return {'error': 'Blob storage not initialized', 'tensor_hash': None}
+
+        try:
+            # Check the tensor size before serialization
+            tensor_size = self._get_tensor_size_bytes(tensor)
+            if tensor_size > TRITONPARSE_TENSOR_SIZE_LIMIT:
+                log.warning(
+                    f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes, skipping blob storage"
+                )
+                return {
+                    'error': f'Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes',
+                    'tensor_hash': None,
+                }
+
+            # Serialize the tensor using torch.save
+            # TODO: Consider async serialization for very large tensors to avoid blocking
+            import io
+            buffer = io.BytesIO()
+            if TORCH_INSTALLED:
+                torch.save(tensor.cpu(), buffer)
+            else:
+                return {'error': 'PyTorch not available for tensor serialization', 'tensor_hash': None}
+
+            blob_data = buffer.getvalue()
+            hash_hex = self._compute_hash(blob_data)
+
+            # Reuse an existing blob with the same content hash, if present
+            if hash_hex in self.hash_to_path_cache:
+                blob_path = self.hash_to_path_cache[hash_hex]
+                if blob_path.exists():
+                    return {
+                        'tensor_hash': hash_hex,
+                        'blob_path': str(blob_path),
+                        'blob_size': len(blob_data),
+                        'serialization_method': 'torch_save',
+                    }
+
+            # Create the blob file
+            blob_path = self._get_blob_path(hash_hex)
+            blob_path.parent.mkdir(parents=True, exist_ok=True)
+
+            # Atomic write using a temporary file + rename
+            with tempfile.NamedTemporaryFile(
+                mode='wb',
+                dir=blob_path.parent,
+                prefix=f".tmp_{hash_hex}_",
+                delete=False,
+            ) as tmp_file:
+                tmp_file.write(blob_data)
+                tmp_path = Path(tmp_file.name)
+
+            # Atomic rename
+            tmp_path.rename(blob_path)
+
+            # Cache the path
+            self.hash_to_path_cache[hash_hex] = blob_path
+
+            log.debug(f"Saved tensor blob: {hash_hex} -> {blob_path}")
+
+            return {
+                'tensor_hash': hash_hex,
+                'blob_path': str(blob_path),
+                'blob_size': len(blob_data),
+                'serialization_method': 'torch_save',
+            }
+
+        except Exception as e:
+            error_msg = f"Failed to save tensor blob: {str(e)}"
+            log.error(error_msg)
+            return {'error': error_msg, 'tensor_hash': None}
+
+
 class TritonLogRecord(logging.LogRecord):
     """
     Custom LogRecord class for structured logging of Triton operations.
@@ -461,7 +599,7 @@ def init_logs():
     DEBUG:tritonparse_trace: lines by blocking propagation to the root logger.
     """
 
-    global TRITON_TRACE_HANDLER, triton_trace_folder
+    global TRITON_TRACE_HANDLER, triton_trace_folder, TENSOR_BLOB_MANAGER
 
     # Basic logger settings (safe to run on every call)
     triton_trace_log.setLevel(logging.DEBUG)
@@ -486,6 +624,15 @@ def init_logs():
         if TRITON_TRACE_HANDLER not in triton_trace_log.handlers:
             TRITON_TRACE_HANDLER.setFormatter(TritonJsonFormatter())
             triton_trace_log.addHandler(TRITON_TRACE_HANDLER)
+
+    # Initialize the tensor blob manager if enabled
+    if TRITONPARSE_SAVE_TENSOR_BLOBS:
+        if TENSOR_BLOB_MANAGER is None:
+            TENSOR_BLOB_MANAGER = TensorBlobManager()
+
+        # Set the root directory for blob storage if it has not been set yet
+        if root_dir and TENSOR_BLOB_MANAGER.root_dir is None:
+            TENSOR_BLOB_MANAGER.set_root_dir(root_dir)
 
 
 def trace_structured_triton(
@@ -613,6 +760,182 @@ def maybe_trace_triton(
     return trace_data
 
 
+from triton.knobs import LaunchHook, JITHook
+
+
+def extract_arg_info(arg_dict):
+    """
+    Extract detailed information from kernel arguments, especially for PyTorch tensors.
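+
+    Illustrative example (a sketch with hypothetical argument names; assumes
+    PyTorch with a CUDA device):
+
+        args = {"x_ptr": torch.empty(4, device="cuda"), "n_elements": 4}
+        info = extract_arg_info(args)
+        # info["x_ptr"] -> {"type": "tensor", "shape": [4], "dtype": "torch.float32", ...}
+        # info["n_elements"] -> {"type": "int", "value": 4}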
+
+    Args:
+        arg_dict: Dictionary of kernel arguments
+
+    Returns:
+        Dictionary with extracted argument information, including tensor properties
+    """
+    global TENSOR_BLOB_MANAGER
+
+    extracted_args = {}
+
+    for arg_name, arg_value in arg_dict.items():
+        arg_info = {}
+
+        # Check if it's a PyTorch tensor
+        if hasattr(arg_value, 'shape') and hasattr(arg_value, 'dtype'):
+            arg_info['type'] = 'tensor'
+            arg_info['shape'] = list(arg_value.shape)
+            arg_info['dtype'] = str(arg_value.dtype)
+            arg_info['device'] = str(arg_value.device)
+            arg_info['stride'] = list(arg_value.stride())
+            arg_info['numel'] = arg_value.numel()
+            arg_info['is_contiguous'] = arg_value.is_contiguous()
+            arg_info['element_size'] = arg_value.element_size()
+            arg_info['storage_offset'] = arg_value.storage_offset()
+            # Memory usage in bytes
+            arg_info['memory_usage'] = arg_value.numel() * arg_value.element_size()
+            # Add data_ptr for memory tracking (optional)
+            if hasattr(arg_value, 'data_ptr'):
+                arg_info['data_ptr'] = hex(arg_value.data_ptr())
+
+            # Add tensor blob storage if enabled
+            if TRITONPARSE_SAVE_TENSOR_BLOBS and TENSOR_BLOB_MANAGER is not None:
+                blob_info = TENSOR_BLOB_MANAGER.save_tensor_blob(arg_value)
+                arg_info.update(blob_info)
+
+        # Handle scalar values
+        elif isinstance(arg_value, (int, float, bool)):
+            arg_info['type'] = type(arg_value).__name__
+            arg_info['value'] = arg_value
+        # Handle strings
+        elif isinstance(arg_value, str):
+            arg_info['type'] = 'str'
+            arg_info['value'] = arg_value
+            arg_info['length'] = len(arg_value)
+        # Handle other types
+        else:
+            arg_info['type'] = type(arg_value).__name__
+            # Try to convert to a string for logging, but be safe about it
+            try:
+                arg_info['repr'] = str(arg_value)
+                if len(arg_info['repr']) > 200:  # Truncate very long representations
+                    arg_info['repr'] = arg_info['repr'][:200] + "..."
+            except Exception:
+                arg_info['repr'] = f"<{type(arg_value).__name__} object>"
+
+        extracted_args[arg_name] = arg_info
+
+    return extracted_args
+
+
+def add_launch_metadata(grid, metadata, arg_dict):
+    # Extract detailed argument information
+    extracted_args = extract_arg_info(arg_dict)
+    return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)}
+
+
+class JITHookImpl(JITHook):
+    """
+    JIT hook implementation that overrides or sets the launch_metadata function for Triton kernels.
+
+    This hook is essential for capturing detailed kernel launch information beyond the basic
+    metadata (like the kernel name) that Triton provides by default. Without a custom
+    launch_metadata function, only minimal launch information is available, as shown in:
+    https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504
+
+    By intercepting the JIT compilation process and setting a custom launch_metadata function,
+    we can capture comprehensive runtime information including grid parameters, kernel metadata,
+    and argument dictionaries for detailed analysis and logging.
+    """
+
+    def __call__(
+        self,
+        *,
+        key: str,
+        repr: str,
+        fn,
+        compile,
+        is_manual_warmup: bool,
+        already_compiled: bool,
+    ) -> Optional[bool]:
+        """
+        Override or set the launch_metadata function for the JIT-compiled kernel.
+
+        This method is called during the JIT compilation process and allows us to
+        inject our custom launch_metadata function, which will be used to collect
+        detailed kernel launch information.
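+
+        Once installed, every subsequent launch carries one extra metadata entry
+        produced by add_launch_metadata above (a sketch of the shape):
+
+            {"launch_metadata_tritonparse": (grid, metadata, extracted_args)}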
+
+        Args:
+            key: Unique identifier for the kernel
+            repr: String representation of the kernel
+            fn: The JIT function object
+            compile: Compilation function
+            is_manual_warmup: Whether this is a manual warmup call
+            already_compiled: Whether the kernel is already compiled
+
+        Returns:
+            True to continue with compilation, None/False to skip
+        """
+        launch_metadata_fn = fn.jit_function.launch_metadata
+        if launch_metadata_fn is not None:
+            log.warning(
+                f"fn {fn} already has launch_metadata set to {launch_metadata_fn}; it will be overridden by tritonparse."
+            )
+        fn.jit_function.launch_metadata = add_launch_metadata
+        return True
+
+
+class LaunchHookImpl(LaunchHook):
+    """
+    Launch hook implementation for capturing and logging kernel launch metadata.
+
+    This hook intercepts kernel launches and extracts the detailed metadata that was
+    set up by JITHookImpl. It provides an entry point at kernel launch, allowing
+    comprehensive logging and analysis of kernel launches, including timing,
+    parameters, and execution context.
+
+    The metadata captured includes:
+    - Kernel name and function details
+    - Grid dimensions and launch parameters
+    - Kernel arguments and their values
+    - Stream information
+    - Custom metadata added by the launch_metadata function
+    """
+
+    def enter(self, metadata):
+        """
+        Handle the kernel launch entry point.
+
+        This method is called when a kernel is about to be launched, providing
+        access to all the launch metadata for logging, profiling, or analysis.
+
+        Args:
+            metadata: LazyDict containing comprehensive launch information, including
+                the kernel name, function, stream, grid parameters, and custom data, e.g.:
+                {'name': 'add_kernel', 'function': None, 'stream': 0,
+                'launch_metadata_tritonparse': (grid, self.metadata, extracted_args)}
+                where extracted_args contains detailed info for each argument:
+                - For tensors: shape, dtype, device, stride, memory_usage, etc.
+                - For scalars: type and value
+                - For other types: type and string representation
+                LazyDict is defined here:
+                https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L512
+        """
+        trace_data = defaultdict(dict)
+        metadata_dict = metadata.get()
+        trace_data["name"] = metadata_dict["name"]
+        trace_data["function"] = metadata_dict["function"]
+        trace_data["stream"] = metadata_dict["stream"]
+        launch_metadata_tritonparse = metadata_dict.get("launch_metadata_tritonparse", None)
+        if launch_metadata_tritonparse is not None:
+            trace_data["grid"] = launch_metadata_tritonparse[0]
+            trace_data["metadata"] = launch_metadata_tritonparse[1]
+            # Detailed per-argument info produced by extract_arg_info
+            trace_data["extracted_args"] = launch_metadata_tritonparse[2]
+        trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
+
+
 def init(trace_folder: Optional[str] = None):
     """
     Initialize the structured logging system for Triton compilation.
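+
+    Illustrative usage (a sketch; the environment variables are read at import
+    time, so set them before importing this module):
+
+        # TRITON_TRACE_LAUNCH=1 TRITONPARSE_SAVE_TENSOR_BLOBS=1 python train.py
+        import tritonparse.structured_logging
+        tritonparse.structured_logging.init("./logs")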