convert : ability to lazy-load safetensors remotely without downloading to disk #12820


Merged · 14 commits · Apr 10, 2025
15 changes: 15 additions & 0 deletions convert_hf_to_gguf.py
@@ -5470,6 +5470,10 @@ def parse_args() -> argparse.Namespace:
"--print-supported-models", action="store_true",
help="Print the supported models"
)
parser.add_argument(
"--remote", action="store_true",
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B'",
)

args = parser.parse_args()
if not args.print_supported_models and args.model is None:
@@ -5510,6 +5514,14 @@ def main() -> None:

    dir_model = args.model

    if args.remote:
        from huggingface_hub import snapshot_download
        local_dir = snapshot_download(
            repo_id=str(dir_model),
            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
        dir_model = Path(local_dir)
        logger.info(f"Downloaded config and tokenizer to {local_dir}")

    if not dir_model.is_dir():
        logger.error(f'Error: {args.model} is not a directory')
        sys.exit(1)
@@ -5531,6 +5543,9 @@ def main() -> None:

    if args.outfile is not None:
        fname_out = args.outfile
    elif args.remote:
        # if remote, use the model ID as the output file name
        fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
    else:
        fname_out = dir_model

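Usage sketch for the new flag (the repo name comes from the flag's help text above; per the args.remote branch in main, the output lands in the current directory as HuggingFaceTB-SmolLM2-1.7B-{ftype}.gguf, with {ftype} filled from the chosen output type):

    python convert_hf_to_gguf.py --remote HuggingFaceTB/SmolLM2-1.7B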
172 changes: 172 additions & 0 deletions gguf-py/gguf/utility.py
@@ -2,6 +2,8 @@

from typing import Literal

import json


def fill_templated_filename(filename: str, output_type: str | None) -> str:
    # Given a file name, fill in any type templates, e.g. 'some-model-name.{ftype}.gguf'
@@ -67,3 +69,173 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
    kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""

    return f"{name}{parameters}{finetune}{version}{encoding}{kind}"


class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has a single safetensors file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(json.dumps(tensors, indent=2))

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, meta in tensors.items():
            dtype, shape, offset_start, size, remote_safetensor_url = meta
            # read the tensor data
            data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
            print(data)
    """

    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
        """
        Get the list of tensors from a Hugging Face model repository.

        Returns a dictionary of tensor names and their metadata.
        Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
        """
        # case 1: model has only a single model.safetensors file
        is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
        if is_single_file:
            url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
            tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
            for key, val in cls.get_list_tensors(url).items():
                tensors[key] = (*val, url)  # populate the url
            return tensors

        # case 2: model has multiple files
        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
        is_multiple_files = cls.check_file_exist(index_url)
        if is_multiple_files:
            # read the index file
            index_data = cls.get_data_by_range(index_url, 0)
            index_str = index_data.decode('utf-8')
            index_json = json.loads(index_str)
            assert index_json.get("weight_map") is not None, "weight_map not found in index file"
            weight_map = index_json["weight_map"]
            # get the list of files
            all_files = list(set(weight_map.values()))
            all_files.sort()  # make sure we load shard files in order
            # get the list of tensors
            tensors = {}
            for file in all_files:
                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                for key, val in cls.get_list_tensors(url).items():
                    tensors[key] = (*val, url)  # populate the url
            return tensors

        raise ValueError(f"Model {model_id} does not have any safetensor files")

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
        """
        Get the list of tensors from a remote safetensor file.

        Returns a dictionary of tensor names and their metadata.
        Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, tuple[str, list[int], int, int]] = {}

        for name, meta in metadata.items():
            if name == "__metadata__":
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                offset_start_relative, offset_end_relative = meta["data_offsets"]
                size = offset_end_relative - offset_start_relative
                offset_start = data_start_offset + offset_start_relative
                res[name] = (dtype, shape, offset_start, size)
@compilade (Collaborator), Apr 8, 2025:
Nice, this should have all information needed.

A dataclass (or even a NamedTuple) for remote tensors would be useful, since there will also need to be a function to turn one into either a NumPy ndarray or a PyTorch Tensor, whichever is simpler at first.

A lazy tensor is built from metadata plus a function that produces the tensor; the function is called only when the data is needed (and only once per tensor).

With such a function, it should be simpler to add a from_remote_tensor method to LazyTorchTensor. To map the safetensors types onto PyTorch types, it may be simplest to let that conversion live in LazyTorchTensor: expose only a dataclass or NamedTuple for remote tensors and let LazyTorchTensor.from_remote_tensor handle the rest.
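A minimal sketch of that lazy pattern (illustrative names, not the PR's actual API; the real LazyTorchTensor lives in convert_hf_to_gguf.py):

    from __future__ import annotations
    from dataclasses import dataclass, field
    from typing import Callable

    @dataclass
    class RemoteLazyTensor:
        dtype: str                  # safetensors dtype string, e.g. "F32"
        shape: list[int]
        _load: Callable[[], bytes]  # fetches the raw bytes on demand
        _data: bytes | None = field(default=None, repr=False)

        def data(self) -> bytes:
            # called only when the data is needed, and fetched only once
            if self._data is None:
                self._data = self._load()
            return self._data

    # usage: defer the ranged download until .data() is first called
    # t = RemoteLazyTensor(dtype, shape,
    #                      lambda: SafetensorRemote.get_data_by_range(url, offset_start, size))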

Author (Collaborator):
Nice, thanks for the confirmation. Could you please go ahead and implement the from_remote_tensor? Feel free to push directly to this PR, thanks!

@compilade (Collaborator):
> Could you please go ahead and implement the from_remote_tensor? Feel free to push directly to this PR, thanks!

I will, once I get somewhere more convenient (currently commuting on public transit).

            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)
        """
        # Request the first 5MB of the file (hopefully enough for the metadata)
        read_size = 5 * 1024 * 1024
        raw_data = cls.get_data_by_range(url, 0, read_size)

        # Parse the header
        # The first 8 bytes contain the metadata length as a u64 little-endian integer
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')

        # Calculate the data start offset
        data_start_offset = 8 + metadata_length
        alignment = SafetensorRemote.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        # Check that we have enough data to read the full metadata
        if len(raw_data) < 8 + metadata_length:
            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")

        # Extract the metadata bytes and parse them as JSON
        metadata_bytes = raw_data[8:8 + metadata_length]
        metadata_str = metadata_bytes.decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
            return metadata, data_start_offset
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") from e

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, the entire file will be read.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        headers = {}
@compilade (Collaborator):
I think it would still be useful to specify the User-Agent, as in the _get_request_headers method from 42fc895.

Author (Collaborator):
Thanks for noticing, done in e8b7d26
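For context, a sketch of what such a headers helper might look like (hypothetical reconstruction; the actual method is in the commits referenced above, and the HF_TOKEN handling is an assumption):

    import os

    @classmethod
    def _get_request_headers(cls) -> dict[str, str]:
        # identify the converter to the server
        headers = {"User-Agent": "convert_hf_to_gguf"}
        # assumption: honor an HF_TOKEN env var for gated/private repos
        if os.environ.get("HF_TOKEN"):
            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
        return headers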

        if size > -1:
            # HTTP Range ends are inclusive, hence the -1
            headers["Range"] = f"bytes={start}-{start + size - 1}"
        response = requests.get(url, allow_redirects=True, headers=headers)
Author (Collaborator):
Another idea: this line could be multithreaded when the size passes a certain threshold, but I'll take a look at that later.

Author (Collaborator):
Done in 42fc895

(I vibe-coded it with Gemini 2.5 Pro.)
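A sketch of that chunked, multithreaded range fetch (illustrative only; the actual implementation landed in 42fc895, and the function name, chunk size, and worker count here are assumptions):

    from concurrent.futures import ThreadPoolExecutor

    def fetch_range_parallel(url: str, start: int, size: int,
                             chunk_size: int = 8 * 1024 * 1024, max_workers: int = 8) -> bytes:
        import requests

        def fetch_chunk(chunk_start: int, chunk_len: int) -> bytes:
            # HTTP Range ends are inclusive
            headers = {"Range": f"bytes={chunk_start}-{chunk_start + chunk_len - 1}"}
            resp = requests.get(url, allow_redirects=True, headers=headers)
            resp.raise_for_status()
            return resp.content

        # split [start, start + size) into fixed-size chunks
        ranges = []
        pos = start
        while pos < start + size:
            length = min(chunk_size, start + size - pos)
            ranges.append((pos, length))
            pos += length

        # fetch chunks concurrently; map() preserves order for reassembly
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            parts = pool.map(lambda r: fetch_chunk(*r), ranges)
        return b"".join(parts)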

        response.raise_for_status()

        # Get raw byte data
        return response.content[:size]

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            headers = {"Range": "bytes=0-0"}
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
        except requests.RequestException:
            return False
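For reference, the 8-byte length prefix that get_metadata parses can be exercised against a synthetic safetensors header (a self-contained sketch; the tensor entry is made up):

    import json

    # build a minimal safetensors-style blob: u64 little-endian metadata length, then JSON, then data
    meta = {"t": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
    meta_bytes = json.dumps(meta).encode("utf-8")
    blob = len(meta_bytes).to_bytes(8, byteorder="little") + meta_bytes + b"\x00" * 16

    # the same parsing steps as get_metadata, against the local blob
    metadata_length = int.from_bytes(blob[:8], byteorder="little")
    parsed = json.loads(blob[8:8 + metadata_length].decode("utf-8"))
    assert parsed["t"]["data_offsets"] == [0, 16]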