-
Notifications
You must be signed in to change notification settings - Fork 12.4k
convert : ability to lazy-load safetensors remotely without downloading to disk #12820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
2507c8e
7f61d0b
08ecbbe
3a3682d
df95a3a
4f65762
b584e39
78094fc
4c0170e
42fc895
63f0604
c8760cc
2e535f6
e8b7d26
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,10 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Literal | ||
|
||
import json | ||
|
||
|
||
def fill_templated_filename(filename: str, output_type: str | None) -> str: | ||
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf' | ||
|
@@ -67,3 +70,184 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st | |
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else "" | ||
|
||
return f"{name}{parameters}{finetune}{version}{encoding}{kind}" | ||
|
||
|
||
@dataclass
class RemoteTensor:
    """
    Description of one tensor stored inside a remote safetensors file.

    Instances are built by SafetensorRemote.get_list_tensors from the file's
    JSON header; the tensor bytes themselves are only downloaded when data()
    is called.
    """

    # dtype string exactly as it appears in the safetensors header
    dtype: str
    # tensor dimensions, converted from the header's "shape" list
    shape: tuple[int, ...]
    # absolute byte offset of the first tensor byte within the remote file
    offset_start: int
    # number of bytes occupied by the tensor data
    size: int
    # URL of the safetensors file that holds this tensor
    url: str

    def data(self) -> bytes:
        """
        Download and return the raw bytes of this tensor.

        Issues a ranged read of `size` bytes starting at `offset_start`
        against `url`. Errors raised by the underlying request propagate
        to the caller unchanged.
        """
        # TODO: handle request errors (maybe with limited retries?)
        return SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)
|
||
|
||
class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has single safetensor file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(list(tensors.keys()))

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, tensor in tensors.items():
            # read the tensor data
            data = SafetensorRemote.get_data_by_range(tensor.url, tensor.offset_start, tensor.size)
            print(name, tensor.dtype, tensor.shape, len(data))
    """

    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes
    # Number of bytes fetched up front when looking for the JSON header.
    HEADER_PROBE_SIZE = 5 * 1024 * 1024
    # Largest JSON header we are willing to download. Presumably matches the
    # cap used by the safetensors reference implementation - TODO confirm.
    MAX_HEADER_SIZE = 100_000_000

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a Hugging Face model repository.

        Returns a dictionary mapping each tensor name to a RemoteTensor
        (dtype, shape, offset_start, size, url).

        Raises ValueError if the repository has neither a `model.safetensors`
        file nor a usable `model.safetensors.index.json`.
        """
        # case 1: model has only one single model.safetensors file
        single_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
        if cls.check_file_exist(single_url):
            return cls.get_list_tensors(single_url)

        # case 2: model has multiple files, listed in an index file
        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
        if cls.check_file_exist(index_url):
            index_json = json.loads(cls.get_data_by_range(index_url, 0).decode('utf-8'))
            weight_map = index_json.get("weight_map")
            # explicit raise instead of assert: the index is remote input and
            # asserts disappear when Python runs with -O
            if weight_map is None:
                raise ValueError("weight_map not found in index file")
            tensors: dict[str, RemoteTensor] = {}
            # de-duplicate the shard names and load them in sorted order
            for file in sorted(set(weight_map.values())):
                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                tensors.update(cls.get_list_tensors(url))
            return tensors

        raise ValueError(f"Model {model_id} does not have any safetensor files")

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a remote safetensor file.

        Returns a dictionary mapping each tensor name to a RemoteTensor whose
        offset_start is absolute within the file at `url`.

        Raises ValueError if a tensor entry is not a JSON object or lacks one
        of the "dtype", "shape" or "data_offsets" keys.
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, RemoteTensor] = {}

        for name, meta in metadata.items():
            # "__metadata__" is not a tensor entry, skip it
            if name == "__metadata__":
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                offset_start_relative, offset_end_relative = meta["data_offsets"]
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e
            res[name] = RemoteTensor(
                dtype=dtype,
                shape=tuple(shape),
                # header offsets are relative to the start of the data section
                offset_start=data_start_offset + offset_start_relative,
                size=offset_end_relative - offset_start_relative,
                url=url,
            )

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)

        Raises ValueError if the header is truncated, larger than
        MAX_HEADER_SIZE, or is not a JSON object.
        """
        # Request the first HEADER_PROBE_SIZE bytes (usually enough for the header)
        raw_data = cls.get_data_by_range(url, 0, cls.HEADER_PROBE_SIZE)

        # First 8 bytes contain the metadata length as u64 little-endian
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
        header_end = 8 + metadata_length

        if len(raw_data) < header_end and len(raw_data) >= cls.HEADER_PROBE_SIZE:
            # The probe was completely filled, so the file continues past it:
            # fetch the whole header in a second request instead of failing.
            if metadata_length > cls.MAX_HEADER_SIZE:
                raise ValueError(f"Metadata length {metadata_length} exceeds limit of {cls.MAX_HEADER_SIZE} bytes")
            raw_data = cls.get_data_by_range(url, 0, header_end)

        # Check if we have enough data to read the metadata
        if len(raw_data) < header_end:
            raise ValueError(f"Could not read complete metadata. Need {header_end} bytes, got {len(raw_data)}")

        # Calculate the data start offset, rounded up to ALIGNMENT.
        # NOTE(review): rounding up assumes the writer aligned the start of the
        # data section - confirm against the safetensors format description.
        remainder = header_end % cls.ALIGNMENT
        data_start_offset = header_end + (cls.ALIGNMENT - remainder if remainder else 0)

        # Extract metadata bytes and parse as JSON
        metadata_str = raw_data[8:header_end].decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") from e
        if not isinstance(metadata, dict):
            raise ValueError(f"Failed to parse safetensor metadata as JSON: expected an object, got {type(metadata).__name__}")
        return metadata, data_start_offset

    @staticmethod
    def _range_headers(start: int, size: int) -> dict[str, str]:
        """
        Build the request headers for reading `size` bytes from `start`.

        A negative size means "until the end of the file". HTTP byte ranges
        are inclusive on both ends, hence the `- 1` on the last position.
        The caller must handle size == 0 itself (no valid range exists).
        """
        if size < 0:
            # whole file from byte 0 needs no Range header at all
            return {"Range": f"bytes={start}-"} if start > 0 else {}
        return {"Range": f"bytes={start}-{start + size - 1}"}

    @staticmethod
    def _trim_payload(payload: bytes, start: int, size: int, range_honored: bool) -> bytes:
        """
        Cut a response body down to the bytes that were actually asked for.

        If the server did not answer with partial content it sent the file
        from offset 0, so the leading `start` bytes are dropped first. A
        negative size keeps everything up to the end.
        """
        if not range_honored:
            payload = payload[start:]
        return payload if size < 0 else payload[:size]

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, it will read from `start` to the end of the
        file (the entire file when start is 0).

        Raises ValueError for a URL without scheme or host; HTTP error
        statuses are raised by `response.raise_for_status()`.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        if size == 0:
            # nothing to read, and "bytes=N-(N-1)" is not a valid range
            return b""

        # NOTE(review): a reviewer asked for an explicit User-Agent header on
        # these requests (reported as done in commit e8b7d26); none is set here.
        headers = cls._range_headers(start, size)
        # NOTE(review): a reviewer suggested making this download multithreaded
        # (reported as done in commit 42fc895); it is a single request here.
        response = requests.get(url, allow_redirects=True, headers=headers)
        response.raise_for_status()

        # Get raw byte data (206 = the server honored the Range header)
        return cls._trim_payload(response.content, start, size, response.status_code == 206)

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.

        Raises ValueError for a URL without scheme or host. Request failures
        are reported as False rather than raised.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            headers = {"Range": "bytes=0-0"}
            response = requests.head(url, allow_redirects=True, headers=headers)
        except requests.RequestException:
            return False
        # Success (2xx) or redirect (3xx)
        return 200 <= response.status_code < 400
Uh oh!
There was an error while loading. Please reload this page.