Skip to content

Commit 153bce5

Browse files
FindHaofacebook-github-bot
authored andcommitted
add a load tensor script
Summary: This file provides a utility function for loading tensors saved by the tritonparse library. The `load_tensor` function loads a tensor from a file path and optionally verifies its integrity through hash checking. Changes Made ------------ 1. Added an optional `device` parameter to allow loading tensors to a specified device (e.g., 'cuda:0', 'cpu'). If not specified, tensors remain on their original device. 2. Made the `tensor_hash` parameter optional, allowing tensors to be loaded without hash verification when the hash is not available or verification is not needed. These changes enhance the flexibility of the tensor loading utility while maintaining backward compatibility with existing code. The function now supports three usage patterns: * Loading with both hash verification and device specification * Loading with hash verification only * Loading without verification to a specific device or the original device Reviewed By: davidberard98 Differential Revision: D77264950 fbshipit-source-id: c975a2c27d61c72152671bf75748e3139b7a099f
1 parent fd1c7cf commit 153bce5

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

tritonparse/tools/load_tensor.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Simple tensor loading utility for tritonparse saved tensors.
4+
Usage:
5+
import tritonparse.tools.load_tensor as load_tensor
6+
tensor = load_tensor.load_tensor(tensor_file_path, device)
7+
"""
8+
9+
import hashlib
10+
from pathlib import Path
11+
12+
import torch
13+
14+
15+
def load_tensor(tensor_file_path: str, device: str = None) -> torch.Tensor:
16+
"""
17+
Load a tensor from its file path and verify its integrity using the hash in the filename.
18+
19+
Args:
20+
tensor_file_path (str): Direct path to the tensor .bin file. The filename should be
21+
the hash of the file contents followed by .bin extension.
22+
device (str, optional): Device to load the tensor to (e.g., 'cuda:0', 'cpu').
23+
If None, keeps the tensor on its original device.
24+
25+
Returns:
26+
torch.Tensor: The loaded tensor (moved to the specified device if provided)
27+
28+
Raises:
29+
FileNotFoundError: If the tensor file doesn't exist
30+
RuntimeError: If the tensor cannot be loaded
31+
ValueError: If the computed hash doesn't match the filename hash
32+
"""
33+
blob_path = Path(tensor_file_path)
34+
35+
if not blob_path.exists():
36+
raise FileNotFoundError(f"Tensor blob not found: {blob_path}")
37+
38+
# Extract expected hash from filename (remove .bin extension)
39+
expected_hash = blob_path.stem
40+
41+
# Compute actual hash of file contents
42+
with open(blob_path, "rb") as f:
43+
file_contents = f.read()
44+
computed_hash = hashlib.blake2b(file_contents).hexdigest()
45+
46+
# Verify hash matches filename
47+
if computed_hash != expected_hash:
48+
raise ValueError(
49+
f"Hash verification failed: expected '{expected_hash}' but computed '{computed_hash}'"
50+
)
51+
52+
try:
53+
# Load the tensor using torch.load (tensors are saved with torch.save)
54+
# If device is None, keep tensor on its original device, otherwise move to specified device
55+
tensor = torch.load(blob_path, map_location=device)
56+
return tensor
57+
except Exception as e:
58+
raise RuntimeError(f"Failed to load tensor from {blob_path}: {str(e)}")

0 commit comments

Comments
 (0)