Skip to content

convert : ability to lazy-load safetensors remotely without downloading to disk #12820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 10, 2025
Merged
37 changes: 29 additions & 8 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")

Expand All @@ -83,11 +83,23 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.lazy = not eager or (remote_hf_model_id is not None)
if remote_hf_model_id is not None:
self.is_safetensors = True

def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))

self.get_tensors = get_remote_tensors
else:
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
Expand Down Expand Up @@ -5393,6 +5405,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)

@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype]
shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
return cast(torch.Tensor, lazy)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
del types # unused
Expand Down Expand Up @@ -5516,8 +5536,9 @@ def main() -> None:

if args.remote:
from huggingface_hub import snapshot_download
args.remote = str(dir_model)
local_dir = snapshot_download(
repo_id=str(dir_model),
repo_id=args.remote,
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
Expand Down Expand Up @@ -5569,7 +5590,7 @@ def main() -> None:
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
small_first_shard=args.no_tensor_first_split, remote_hf_model_id=args.remote or None)

if args.vocab_only:
logger.info("Exporting model vocab...")
Expand Down
32 changes: 22 additions & 10 deletions gguf-py/gguf/utility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import json
Expand Down Expand Up @@ -71,6 +72,20 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"


@dataclass
class RemoteTensor:
dtype: str
shape: tuple[int, ...]
offset_start: int
size: int
url: str

def data(self) -> bytes:
# TODO: handle request errors (maybe with limited retries?)
data = SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)
return data


class SafetensorRemote:
"""
Uility class to handle remote safetensor files.
Expand All @@ -94,7 +109,7 @@ class SafetensorRemote:
ALIGNMENT = 8 # bytes

@classmethod
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
"""
Get list of tensors from a Hugging Face model repository.

Expand All @@ -105,10 +120,7 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[i
is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
if is_single_file:
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
for key, val in cls.get_list_tensors(url).items():
tensors[key] = (*val, url) # populate the url
return tensors
return cls.get_list_tensors(url)

# case 2: model has multiple files
index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
Expand All @@ -124,25 +136,25 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[i
all_files = list(set(weight_map.values()))
all_files.sort() # make sure we load shard files in order
# get the list of tensors
tensors = {}
tensors: dict[str, RemoteTensor] = {}
for file in all_files:
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
for key, val in cls.get_list_tensors(url).items():
tensors[key] = (*val, url) # populate the url
tensors[key] = val
return tensors

raise ValueError(f"Model {model_id} does not have any safetensor files")

@classmethod
def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
"""
Get list of tensors from a remote safetensor file.

Returns a dictionary of tensor names and their metadata.
Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
"""
metadata, data_start_offset = cls.get_metadata(url)
res: dict[str, tuple[str, list[int], int, int]] = {}
res: dict[str, RemoteTensor] = {}

for name, meta in metadata.items():
if name == "__metadata__":
Expand All @@ -155,7 +167,7 @@ def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]
offset_start_relative, offset_end_relative = meta["data_offsets"]
size = offset_end_relative - offset_start_relative
offset_start = data_start_offset + offset_start_relative
res[name] = (dtype, shape, offset_start, size)
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")

Expand Down
Loading