|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Simple tensor loading utility for tritonparse saved tensors. |
| 4 | +Usage: |
| 5 | +import tritonparse.tools.load_tensor as load_tensor |
| 6 | +tensor = load_tensor.load_tensor(tensor_file_path, device) |
| 7 | +""" |
| 8 | + |
| 9 | +import hashlib |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | +import torch |
| 13 | + |
| 14 | + |
def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
    """
    Load a tensor from its file path and verify its integrity using the hash in the filename.

    The file is read into memory once; the BLAKE2b digest is computed over those
    bytes and the tensor is deserialized from the very same bytes, so the data
    that is loaded is exactly the data that was verified.

    Args:
        tensor_file_path (str): Direct path to the tensor .bin file. The filename should be
            the BLAKE2b hex digest of the file contents followed by the .bin extension.
        device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
            Passed to torch.load as map_location. If None, keeps the tensor on its
            original device.

    Returns:
        torch.Tensor: The loaded tensor (moved to the specified device if provided)

    Raises:
        FileNotFoundError: If the path does not exist or is not a regular file
        ValueError: If the computed hash doesn't match the filename hash
        RuntimeError: If the tensor cannot be loaded (the underlying error is
            attached as __cause__)
    """
    # Function-scope import keeps this block self-contained without touching
    # the module's import section.
    import io

    blob_path = Path(tensor_file_path)

    # is_file() (rather than exists()) so that a directory path is reported as
    # the documented FileNotFoundError instead of leaking an OSError from open().
    if not blob_path.is_file():
        raise FileNotFoundError(f"Tensor blob not found: {blob_path}")

    # Extract expected hash from filename (remove .bin extension)
    expected_hash = blob_path.stem

    # Compute actual hash of file contents
    file_contents = blob_path.read_bytes()
    computed_hash = hashlib.blake2b(file_contents).hexdigest()

    # Verify hash matches filename
    if computed_hash != expected_hash:
        raise ValueError(
            f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
        )

    try:
        # Tensors are saved with torch.save, so torch.load reads them back.
        # Loading from the in-memory buffer avoids a second disk read and
        # guarantees the verified bytes are the ones being deserialized.
        # NOTE(security): torch.load unpickles its input. The filename hash only
        # proves the file is self-consistent, not that it is trustworthy; do not
        # call this on blobs from an untrusted source.
        return torch.load(io.BytesIO(file_contents), map_location=device)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}") from e
0 commit comments