From e24444981d76c5319ce8eee9c1c8ea92af9d81f2 Mon Sep 17 00:00:00 2001 From: FindHao Date: Mon, 23 Jun 2025 21:22:47 -0700 Subject: [PATCH] Add TensorBlobManager for efficient tensor storage Summary: - Introduced TensorBlobManager class to manage tensor data as content-addressed blobs using BLAKE2b hashing. - Implemented methods for saving tensors, computing hashes, and managing blob storage paths. - Added environment variables to enable tensor blob storage and set size limits. - Updated logging initialization to incorporate the tensor blob manager if enabled. - Enhanced argument extraction to include tensor blob metadata when applicable. This update improves the handling of tensor data, allowing for efficient storage and retrieval while maintaining compatibility with existing logging mechanisms. --- tritonparse/structured_logging.py | 325 +++++++++++++++++++++++++++++- 1 file changed, 324 insertions(+), 1 deletion(-) diff --git a/tritonparse/structured_logging.py b/tritonparse/structured_logging.py index 823e564..9dfd928 100644 --- a/tritonparse/structured_logging.py +++ b/tritonparse/structured_logging.py @@ -1,12 +1,15 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import atexit +import gzip +import hashlib import importlib import inspect import json import logging import math import os +import tempfile from collections import defaultdict from dataclasses import asdict, is_dataclass from datetime import date, datetime @@ -32,7 +35,17 @@ DEFAULT_TRACE_FILE_PREFIX = ( f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_" ) +# Enable launch trace. WARNNING: it will overwrite launch_metadata for each triton kernel. +TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"] +# Enable tensor blob storage +TRITONPARSE_SAVE_TENSOR_BLOBS = os.getenv("TRITONPARSE_SAVE_TENSOR_BLOBS", "0") in ["1", "true", "True"] +# Tensor size limit in bytes (default 10GB) +TRITONPARSE_TENSOR_SIZE_LIMIT = int(os.getenv("TRITONPARSE_TENSOR_SIZE_LIMIT", str(10 * 1024 * 1024 * 1024))) + TRITON_TRACE_HANDLER = None +# Global tensor blob manager instance +TENSOR_BLOB_MANAGER = None + if importlib.util.find_spec("torch") is not None: TORCH_INSTALLED = True import torch @@ -41,6 +54,131 @@ TORCH_INSTALLED = False +class TensorBlobManager: + """ + Manager for storing tensor data as content-addressed blobs. + + Uses BLAKE2b hashing for content addressing and stores blobs in a two-level + directory structure to avoid filesystem limitations with large numbers of files. + """ + + def __init__(self, root_dir: Optional[str] = None): + self.root_dir = None + self.hash_to_path_cache = {} # In-memory cache for hash -> path mapping + if root_dir: + self.set_root_dir(root_dir) + + def set_root_dir(self, root_dir: str): + """Set the root directory for blob storage.""" + self.root_dir = Path(root_dir) / "saved_tensors" + self.root_dir.mkdir(parents=True, exist_ok=True) + log.debug(f"TensorBlobManager: using root directory {self.root_dir}") + + def _compute_hash(self, data: bytes) -> str: + """Compute BLAKE2b hash of the data.""" + return hashlib.blake2b(data).hexdigest() + + def _get_blob_path(self, hash_hex: str) -> Path: + """Get the file path for a given hash using two-level directory structure.""" + if not self.root_dir: + raise ValueError("Root directory not set") + + # Two-level directory: first 2 chars / full_hash.bin + subdir = hash_hex[:2] + filename = f"{hash_hex}.bin" + return (self.root_dir / subdir / filename).resolve() + + def _get_tensor_size_bytes(self, tensor) -> int: + """Get tensor size in bytes before serialization.""" + if hasattr(tensor, 'numel') and hasattr(tensor, 'element_size'): + return tensor.numel() * tensor.element_size() + return 0 + + def save_tensor_blob(self, tensor) -> Dict[str, Any]: + """ + Save tensor as a blob and return metadata. + + Args: + tensor: PyTorch tensor to save + + Returns: + Dictionary with blob metadata or error information: + - Success: {'tensor_hash': str, 'blob_path': str, 'blob_size': int, 'serialization_method': str} + - Error: {'error': str, 'tensor_hash': None} + """ + if not self.root_dir: + return {'error': 'Blob storage not initialized', 'tensor_hash': None} + + try: + # Check tensor size before serialization + tensor_size = self._get_tensor_size_bytes(tensor) + if tensor_size > TRITONPARSE_TENSOR_SIZE_LIMIT: + log.warning( + f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes, skipping blob storage" + ) + return { + 'error': f'Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes', + 'tensor_hash': None + } + + # Serialize tensor using torch.save + # TODO: Consider async serialization for very large tensors to avoid blocking + import io + buffer = io.BytesIO() + if TORCH_INSTALLED: + torch.save(tensor.cpu(), buffer) + else: + return {'error': 'PyTorch not available for tensor serialization', 'tensor_hash': None} + + blob_data = buffer.getvalue() + hash_hex = self._compute_hash(blob_data) + + # Check if we already have this blob + if hash_hex in self.hash_to_path_cache: + blob_path = self.hash_to_path_cache[hash_hex] + if blob_path.exists(): + return { + 'tensor_hash': hash_hex, + 'blob_path': str(blob_path), + 'blob_size': len(blob_data), + 'serialization_method': 'torch_save' + } + + # Create blob file + blob_path = self._get_blob_path(hash_hex) + blob_path.parent.mkdir(parents=True, exist_ok=True) + + # Atomic write using temporary file + rename + with tempfile.NamedTemporaryFile( + mode='wb', + dir=blob_path.parent, + prefix=f".tmp_{hash_hex}_", + delete=False + ) as tmp_file: + tmp_file.write(blob_data) + tmp_path = Path(tmp_file.name) + + # Atomic rename + tmp_path.rename(blob_path) + + # Cache the path + self.hash_to_path_cache[hash_hex] = blob_path + + log.debug(f"Saved tensor blob: {hash_hex} -> {blob_path}") + + return { + 'tensor_hash': hash_hex, + 'blob_path': str(blob_path), + 'blob_size': len(blob_data), + 'serialization_method': 'torch_save' + } + + except Exception as e: + error_msg = f"Failed to save tensor blob: {str(e)}" + log.error(error_msg) + return {'error': error_msg, 'tensor_hash': None} + + class TritonLogRecord(logging.LogRecord): """ Custom LogRecord class for structured logging of Triton operations. @@ -461,7 +599,7 @@ def init_logs(): DEBUG:tritonparse_trace: lines by blocking propagation to the root logger. """ - global TRITON_TRACE_HANDLER, triton_trace_folder + global TRITON_TRACE_HANDLER, triton_trace_folder, TENSOR_BLOB_MANAGER # Basic logger settings (safe to run on every call) triton_trace_log.setLevel(logging.DEBUG) @@ -486,6 +624,15 @@ def init_logs(): if TRITON_TRACE_HANDLER not in triton_trace_log.handlers: TRITON_TRACE_HANDLER.setFormatter(TritonJsonFormatter()) triton_trace_log.addHandler(TRITON_TRACE_HANDLER) + + # Initialize tensor blob manager if enabled + if TRITONPARSE_SAVE_TENSOR_BLOBS: + if TENSOR_BLOB_MANAGER is None: + TENSOR_BLOB_MANAGER = TensorBlobManager() + + # Set or update root directory for blob storage + if root_dir and TENSOR_BLOB_MANAGER.root_dir is None: + TENSOR_BLOB_MANAGER.set_root_dir(root_dir) def trace_structured_triton( @@ -613,6 +760,182 @@ def maybe_trace_triton( return trace_data +from triton.knobs import LaunchHook, JITHook + + +def extract_arg_info(arg_dict): + """ + Extract detailed information from kernel arguments, especially for PyTorch tensors. + + Args: + arg_dict: Dictionary of kernel arguments + + Returns: + Dictionary with extracted argument information including tensor properties + """ + global TENSOR_BLOB_MANAGER + + extracted_args = {} + + for arg_name, arg_value in arg_dict.items(): + arg_info = {} + + # Check if it's a PyTorch tensor + if hasattr(arg_value, 'shape') and hasattr(arg_value, 'dtype'): + arg_info['type'] = 'tensor' + arg_info['shape'] = list(arg_value.shape) + arg_info['dtype'] = str(arg_value.dtype) + arg_info['device'] = str(arg_value.device) + arg_info['stride'] = list(arg_value.stride()) + arg_info['numel'] = arg_value.numel() + arg_info['is_contiguous'] = arg_value.is_contiguous() + arg_info['element_size'] = arg_value.element_size() + arg_info['storage_offset'] = arg_value.storage_offset() + # Memory usage in bytes + arg_info['memory_usage'] = arg_value.numel() * arg_value.element_size() + # Add data_ptr for memory tracking (optional) + if hasattr(arg_value, 'data_ptr'): + arg_info['data_ptr'] = hex(arg_value.data_ptr()) + + # Add tensor blob storage if enabled + if TRITONPARSE_SAVE_TENSOR_BLOBS and TENSOR_BLOB_MANAGER is not None: + blob_info = TENSOR_BLOB_MANAGER.save_tensor_blob(arg_value) + arg_info.update(blob_info) + + # Handle scalar values + elif isinstance(arg_value, (int, float, bool)): + arg_info['type'] = type(arg_value).__name__ + arg_info['value'] = arg_value + # Handle strings + elif isinstance(arg_value, str): + arg_info['type'] = 'str' + arg_info['value'] = arg_value + arg_info['length'] = len(arg_value) + # Handle other types + else: + arg_info['type'] = type(arg_value).__name__ + # Try to convert to string for logging, but be safe about it + try: + arg_info['repr'] = str(arg_value) + if len(arg_info['repr']) > 200: # Truncate very long representations + arg_info['repr'] = arg_info['repr'][:200] + "..." + except: + arg_info['repr'] = f"<{type(arg_value).__name__} object>" + + extracted_args[arg_name] = arg_info + + return extracted_args + + +def add_launch_metadata(grid, metadata, arg_dict): + # Extract detailed argument information + extracted_args = extract_arg_info(arg_dict) + return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)} + + +class JITHookImpl(JITHook): + """ + JIT Hook implementation that overrides or sets the launch_metadata function for Triton kernels. + + This hook is essential for capturing detailed kernel launch information beyond the basic + metadata (like kernel name) that Triton provides by default. Without setting a custom + launch_metadata function, only minimal launch information is available as shown in: + https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504 + + By intercepting the JIT compilation process and setting a custom launch_metadata function, + we can capture comprehensive runtime information including grid parameters, kernel metadata, + and argument dictionaries for detailed analysis and logging. + """ + + def __call__( + self, + *, + key: str, + repr: str, + fn, + compile, + is_manual_warmup: bool, + already_compiled: bool, + ) -> Optional[bool]: + """ + Override or set the launch_metadata function for the JIT-compiled kernel. + + This method is called during the JIT compilation process and allows us to + inject our custom launch_metadata function that will be used to collect + detailed kernel launch information. + + Args: + key: Unique identifier for the kernel + repr: String representation of the kernel + fn: The JIT function object + compile: Compilation function + is_manual_warmup: Whether this is a manual warmup call + already_compiled: Whether the kernel is already compiled + + Returns: + True to continue with compilation, None/False to skip + """ + launch_metadata_fn = fn.jit_function.launch_metadata + if launch_metadata_fn is not None: + log.warning( + f"fn {fn} launch_metadata_fn is not None: {launch_metadata_fn}. It will be overridden by tritonparse." + ) + fn.jit_function.launch_metadata = add_launch_metadata + return True + + +class LaunchHookImpl(LaunchHook): + """ + Launch Hook implementation for capturing and logging kernel launch metadata. + + This hook is responsible for intercepting kernel launches and extracting the detailed + metadata that was set up by the JITHookImpl. It provides entry point for + kernel execution, allowing comprehensive logging and analysis of kernel launches + including timing, parameters, and execution context. + + The metadata captured includes: + - Kernel name and function details + - Grid dimensions and launch parameters + - Kernel arguments and their values + - Stream information + - Custom metadata added by the launch_metadata function + """ + + def enter(self, metadata): + """ + Handle kernel launch entry point. + + This method is called when a kernel is about to be launched, providing + access to all the launch metadata for logging, profiling, or analysis. + metadata format: + + Args: + metadata: LazyDict containing comprehensive launch information including + kernel name, function, stream, grid parameters, and custom data + format: {'name': 'add_kernel', 'function': None, 'stream': 0, + 'launch_metadata_tritonparse': (grid, self.metadata, extracted_args)} + where extracted_args contains detailed info for each argument: + - For tensors: shape, dtype, device, stride, memory_usage, etc. + - For scalars: type and value + - For other types: type and string representation + defined here: + https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/ + python/triton/compiler/compiler.py#L512. + """ + trace_data = defaultdict(dict) + metadata_dict = metadata.get() + trace_data["name"] = metadata_dict["name"] + trace_data["function"] = metadata_dict["function"] + trace_data["stream"] = metadata_dict["stream"] + launch_metadata_tritonparse = metadata_dict.get("launch_metadata_tritonparse", None) + if launch_metadata_tritonparse is not None: + trace_data["grid"] = launch_metadata_tritonparse[0] + trace_data["metadata"] = launch_metadata_tritonparse[1] + trace_data["extracted_args"] = launch_metadata_tritonparse[2] # Now contains detailed arg info + trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data)) + + + def init(trace_folder: Optional[str] = None): """ Initialize the structured logging system for Triton compilation.