From 2507c8e24f328e69355745d2103d2e2b6808e6f6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 12:25:28 +0200 Subject: [PATCH 01/14] gguf util : add SafetensorRemote --- gguf-py/gguf/utility.py | 171 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index ae92d786a4068..720f6af2848f0 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -2,6 +2,8 @@ from typing import Literal +import json + def fill_templated_filename(filename: str, output_type: str | None) -> str: # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf' @@ -67,3 +69,172 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else "" return f"{name}{parameters}{finetune}{version}{encoding}{kind}" + +class SafetensorRemote: + """ + Uility class to handle remote safetensor files. + This class is designed to work with Hugging Face model repositories. + + Example (one model has single safetensor file, the other has multiple): + for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]: + tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) + print(json.dumps(tensors, indent=2)) + + Example reading tensor data: + tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) + for name, meta in tensors.items(): + dtype, shape, offset_start, size, remote_safetensor_url = meta + # read the tensor data + data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size) + print(data) + """ + + BASE_DOMAIN = "https://huggingface.co" + ALIGNMENT = 8 # bytes + + @classmethod + def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]: + """ + Get list of tensors from a Hugging Face model repository. + + Returns a dictionary of tensor names and their metadata. + Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url) + """ + # case 1: model has only one single model.safetensor file + is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors") + if is_single_file: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" + tensors: dict[str, tuple[str, list[int], int, int, str]] = {} + for key, val in cls.get_list_tensors(url).items(): + tensors[key] = (*val, url) # populate the url + return tensors + + # case 2: model has multiple files + index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" + is_multiple_files = cls.check_file_exist(index_url) + if is_multiple_files: + # read the index file + index_data = cls.get_data_by_range(index_url, 0) + index_str = index_data.decode('utf-8') + index_json = json.loads(index_str) + assert index_json.get("weight_map") is not None, "weight_map not found in index file" + weight_map = index_json["weight_map"] + # get the list of files + all_files = list(set(weight_map.values())) + all_files.sort() # make sure we load shard files in order + # get the list of tensors + tensors = {} + for file in all_files: + url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}" + for key, val in cls.get_list_tensors(url).items(): + tensors[key] = (*val, url) # populate the url + return tensors + + raise ValueError(f"Model {model_id} does not have any safetensor files") + + @classmethod + def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]: + """ + Get list of tensors from a remote safetensor file. + + Returns a dictionary of tensor names and their metadata. + Each tensor is represented as a tuple of (dtype, shape, offset_start, size) + """ + metadata, data_start_offset = cls.get_metadata(url) + res: dict[str, tuple[str, list[int], int, int]] = {} + + for name, meta in metadata.items(): + if name == "__metadata__": + continue + if not isinstance(meta, dict): + raise ValueError(f"Invalid metadata for tensor '{name}': {meta}") + try: + dtype = meta["dtype"] + shape = meta["shape"] + offset_start_relative, offset_end_relative = meta["data_offsets"] + size = offset_end_relative - offset_start_relative + offset_start = data_start_offset + offset_start_relative + res[name] = (dtype, shape, offset_start, size) + except KeyError as e: + raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") + + return res + + @classmethod + def get_metadata(cls, url: str) -> tuple[dict, int]: + """ + Get JSON metadata from a remote safetensor file. + + Returns tuple of (metadata, data_start_offset) + """ + # Request first 5MB of the file (hopefully enough for metadata) + read_size = 5 * 1024 * 1024 + raw_data = cls.get_data_by_range(url, 0, read_size) + + # Parse header + # First 8 bytes contain the metadata length as u64 little-endian + if len(raw_data) < 8: + raise ValueError("Not enough data to read metadata size") + metadata_length = int.from_bytes(raw_data[:8], byteorder='little') + + # Calculate the data start offset + data_start_offset = 8 + metadata_length + alignment = SafetensorRemote.ALIGNMENT + if data_start_offset % alignment != 0: + data_start_offset += alignment - (data_start_offset % alignment) + + # Check if we have enough data to read the metadata + if len(raw_data) < 8 + metadata_length: + raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}") + + # Extract metadata bytes and parse as JSON + metadata_bytes = raw_data[8:8 + metadata_length] + metadata_str = metadata_bytes.decode('utf-8') + try: + metadata = json.loads(metadata_str) + return metadata, data_start_offset + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") + + @classmethod + def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: + """ + Get raw byte data from a remote file by range. + If size is not specified, it will read the entire file. + """ + import requests + from urllib.parse import urlparse + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL: {url}") + + headers = {} + if size > -1: + headers = {"Range": f"bytes={start}-{start + size}"} + response = requests.get(url, allow_redirects=True, headers=headers) + response.raise_for_status() + + # Get raw byte data + return response.content[:size] + + @classmethod + def check_file_exist(cls, url: str) -> bool: + """ + Check if a file exists at the given URL. + Returns True if the file exists, False otherwise. + """ + import requests + from urllib.parse import urlparse + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid URL: {url}") + + try: + headers = {"Range": f"bytes=0-0"} + response = requests.head(url, allow_redirects=True, headers=headers) + # Success (2xx) or redirect (3xx) + return 200 <= response.status_code < 400 + except requests.RequestException: + return False From 7f61d0b22f97608d77a782ae7ecdc4f686ce35fc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 12:29:20 +0200 Subject: [PATCH 02/14] fix style --- gguf-py/gguf/utility.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 720f6af2848f0..f68a7fbacfb80 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -70,6 +70,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st return f"{name}{parameters}{finetune}{version}{encoding}{kind}" + class SafetensorRemote: """ Uility class to handle remote safetensor files. @@ -232,7 +233,7 @@ def check_file_exist(cls, url: str) -> bool: raise ValueError(f"Invalid URL: {url}") try: - headers = {"Range": f"bytes=0-0"} + headers = {"Range": "bytes=0-0"} response = requests.head(url, allow_redirects=True, headers=headers) # Success (2xx) or redirect (3xx) return 200 <= response.status_code < 400 From 08ecbbe398af5d352c637b2895516d36a120e3d0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 13:28:55 +0200 Subject: [PATCH 03/14] convert: add --remote option --- convert_hf_to_gguf.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9549900206b48..638a8e2b34674 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5470,6 +5470,10 @@ def parse_args() -> argparse.Namespace: "--print-supported-models", action="store_true", help="Print the supported models" ) + parser.add_argument( + "--remote", action="store_true", + help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B'", + ) args = parser.parse_args() if not args.print_supported_models and args.model is None: @@ -5510,6 +5514,14 @@ def main() -> None: dir_model = args.model + if args.remote: + from huggingface_hub import snapshot_download + local_dir = snapshot_download( + repo_id=str(dir_model), + allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + dir_model = Path(local_dir) + logger.info(f"Downloaded config and tokenizer to {local_dir}") + if not dir_model.is_dir(): logger.error(f'Error: {args.model} is not a directory') sys.exit(1) @@ -5531,6 +5543,9 @@ def main() -> None: if args.outfile is not None: fname_out = args.outfile + elif args.remote: + # if remote, use the model ID as the output file name + fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf") else: fname_out = dir_model From 3a3682de0b57f4aaa57439c0fd4231e20a64a8f6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Apr 2025 10:26:24 -0400 Subject: [PATCH 04/14] convert : allow using lazy remote tensors It's a bit slow for now since everything is blocking and single-threaded. --- convert_hf_to_gguf.py | 37 +++++++++++++++++++++++++++++-------- gguf-py/gguf/utility.py | 32 ++++++++++++++++++++++---------- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 638a8e2b34674..465c411f16bd2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -73,7 +73,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None): + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -83,11 +83,23 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.is_big_endian = is_big_endian self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file - self.lazy = not eager - self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") - self.is_safetensors = len(self.part_names) > 0 - if not self.is_safetensors: - self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") + self.lazy = not eager or (remote_hf_model_id is not None) + if remote_hf_model_id is not None: + self.is_safetensors = True + + def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: + logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") + remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + self.tensor_names = set(name for name in remote_tensors.keys()) + for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items(): + yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) + + self.get_tensors = get_remote_tensors + else: + self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") + self.is_safetensors = len(self.part_names) > 0 + if not self.is_safetensors: + self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -5393,6 +5405,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor: lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) return cast(torch.Tensor, lazy) + @classmethod + def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): + dtype = cls._dtype_str_map[remote_tensor.dtype] + shape = remote_tensor.shape + meta = cls.meta_with_dtype_and_shape(dtype, shape) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + return cast(torch.Tensor, lazy) + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): del types # unused @@ -5516,8 +5536,9 @@ def main() -> None: if args.remote: from huggingface_hub import snapshot_download + args.remote = str(dir_model) local_dir = snapshot_download( - repo_id=str(dir_model), + repo_id=args.remote, allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) dir_model = Path(local_dir) logger.info(f"Downloaded config and tokenizer to {local_dir}") @@ -5569,7 +5590,7 @@ def main() -> None: metadata_override=args.metadata, model_name=args.model_name, split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split) + small_first_shard=args.no_tensor_first_split, remote_hf_model_id=args.remote or None) if args.vocab_only: logger.info("Exporting model vocab...") diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index f68a7fbacfb80..5bae22e7550bf 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Literal import json @@ -71,6 +72,20 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st return f"{name}{parameters}{finetune}{version}{encoding}{kind}" +@dataclass +class RemoteTensor: + dtype: str + shape: tuple[int, ...] + offset_start: int + size: int + url: str + + def data(self) -> bytes: + # TODO: handle request errors (maybe with limited retries?) + data = SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size) + return data + + class SafetensorRemote: """ Uility class to handle remote safetensor files. @@ -94,7 +109,7 @@ class SafetensorRemote: ALIGNMENT = 8 # bytes @classmethod - def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]: + def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: """ Get list of tensors from a Hugging Face model repository. @@ -105,10 +120,7 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[i is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors") if is_single_file: url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" - tensors: dict[str, tuple[str, list[int], int, int, str]] = {} - for key, val in cls.get_list_tensors(url).items(): - tensors[key] = (*val, url) # populate the url - return tensors + return cls.get_list_tensors(url) # case 2: model has multiple files index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" @@ -124,17 +136,17 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[i all_files = list(set(weight_map.values())) all_files.sort() # make sure we load shard files in order # get the list of tensors - tensors = {} + tensors: dict[str, RemoteTensor] = {} for file in all_files: url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}" for key, val in cls.get_list_tensors(url).items(): - tensors[key] = (*val, url) # populate the url + tensors[key] = val return tensors raise ValueError(f"Model {model_id} does not have any safetensor files") @classmethod - def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]: + def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]: """ Get list of tensors from a remote safetensor file. @@ -142,7 +154,7 @@ def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int] Each tensor is represented as a tuple of (dtype, shape, offset_start, size) """ metadata, data_start_offset = cls.get_metadata(url) - res: dict[str, tuple[str, list[int], int, int]] = {} + res: dict[str, RemoteTensor] = {} for name, meta in metadata.items(): if name == "__metadata__": @@ -155,7 +167,7 @@ def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int] offset_start_relative, offset_end_relative = meta["data_offsets"] size = offset_end_relative - offset_start_relative offset_start = data_start_offset + offset_start_relative - res[name] = (dtype, shape, offset_start, size) + res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url) except KeyError as e: raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") From df95a3a7b38282bce0f7c9bfe764b84bd119c3b9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 16:55:20 +0200 Subject: [PATCH 05/14] correct metadata.name --- convert_hf_to_gguf.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 465c411f16bd2..64e653e79ede4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -65,6 +65,7 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path + remote_hf_model_id: str | None # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -84,6 +85,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager or (remote_hf_model_id is not None) + self.remote_hf_model_id = remote_hf_model_id if remote_hf_model_id is not None: self.is_safetensors = True @@ -405,6 +407,10 @@ def prepare_metadata(self, vocab_only: bool): self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params) + # If we are using HF model id, set the metadata name to the model id + if self.remote_hf_model_id: + self.metadata.name = self.remote_hf_model_id + # Fallback to model directory name if metadata name is still missing if self.metadata.name is None: self.metadata.name = self.dir_model.name From 4f657627689a1203d08c712b11bc6ae1b7919444 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 17:03:14 +0200 Subject: [PATCH 06/14] small style fix --- convert_hf_to_gguf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 64e653e79ede4..008a41764ad91 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5498,7 +5498,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--remote", action="store_true", - help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B'", + help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'", ) args = parser.parse_args() @@ -5542,9 +5542,8 @@ def main() -> None: if args.remote: from huggingface_hub import snapshot_download - args.remote = str(dir_model) local_dir = snapshot_download( - repo_id=args.remote, + repo_id=str(dir_model), allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) dir_model = Path(local_dir) logger.info(f"Downloaded config and tokenizer to {local_dir}") @@ -5596,7 +5595,8 @@ def main() -> None: metadata_override=args.metadata, model_name=args.model_name, split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split, remote_hf_model_id=args.remote or None) + small_first_shard=args.no_tensor_first_split, + remote_hf_model_id=str(args.model) if args.remote else None) if args.vocab_only: logger.info("Exporting model vocab...") From b584e3979dec9f54b9278bef4307d4abf44ac108 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 17:24:28 +0200 Subject: [PATCH 07/14] support HF_TOKEN --- convert_hf_to_gguf.py | 2 +- gguf-py/gguf/utility.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 008a41764ad91..72d2f7a815bf8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5498,7 +5498,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--remote", action="store_true", - help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'", + help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", ) args = parser.parse_args() diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 5bae22e7550bf..62ccc4f9303a8 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Literal +import os import json @@ -94,7 +95,7 @@ class SafetensorRemote: Example (one model has single safetensor file, the other has multiple): for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]: tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) - print(json.dumps(tensors, indent=2)) + print(tensors) Example reading tensor data: tensors = SafetensorRemote.get_list_tensors_hf_model(model_id) @@ -223,8 +224,10 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: raise ValueError(f"Invalid URL: {url}") headers = {} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" if size > -1: - headers = {"Range": f"bytes={start}-{start + size}"} + headers["Range"] = f"bytes={start}-{start + size}" response = requests.get(url, allow_redirects=True, headers=headers) response.raise_for_status() @@ -246,6 +249,8 @@ def check_file_exist(cls, url: str) -> bool: try: headers = {"Range": "bytes=0-0"} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" response = requests.head(url, allow_redirects=True, headers=headers) # Success (2xx) or redirect (3xx) return 200 <= response.status_code < 400 From 78094fc4a7d9077de53446b290b6ee8f6e85736c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Apr 2025 11:40:12 -0400 Subject: [PATCH 08/14] convert : use writeable buffer for remote lazy tensors --- convert_hf_to_gguf.py | 3 ++- gguf-py/gguf/utility.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 72d2f7a815bf8..b47a3d85ec4a7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5416,7 +5416,8 @@ def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): dtype = cls._dtype_str_map[remote_tensor.dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + func = lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape) + lazy = cls(meta=meta, args=(remote_tensor,), func=func) return cast(torch.Tensor, lazy) @classmethod diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 62ccc4f9303a8..89a9d494db0ee 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -81,9 +81,10 @@ class RemoteTensor: size: int url: str - def data(self) -> bytes: + def data(self) -> bytearray: # TODO: handle request errors (maybe with limited retries?) - data = SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size) + # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable + data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)) return data From 4c0170e20642311dadddaf796185a0e60da5c363 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Apr 2025 11:47:58 -0400 Subject: [PATCH 09/14] convert : fix flake8 lint regarding lamdba assigment --- convert_hf_to_gguf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b47a3d85ec4a7..72d2f7a815bf8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5416,8 +5416,7 @@ def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): dtype = cls._dtype_str_map[remote_tensor.dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - func = lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=func) + lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) return cast(torch.Tensor, lazy) @classmethod From 42fc895ace385edc972ad819c76c704aeea61791 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 18:05:48 +0200 Subject: [PATCH 10/14] multithreaded download --- gguf-py/gguf/utility.py | 157 ++++++++++++++++++++++++++++++++++------ 1 file changed, 135 insertions(+), 22 deletions(-) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 89a9d494db0ee..c42b9cb1505a1 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -1,10 +1,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal +from typing import Literal, Any import os import json +import requests +import threading +from urllib.parse import urlparse def fill_templated_filename(filename: str, output_type: str | None) -> str: @@ -110,6 +113,10 @@ class SafetensorRemote: BASE_DOMAIN = "https://huggingface.co" ALIGNMENT = 8 # bytes + # start using multithread download for files larger than 100MB + MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024 + MULTITHREAD_COUNT = 8 # number of threads + @classmethod def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: """ @@ -211,29 +218,139 @@ def get_metadata(cls, url: str) -> tuple[dict, int]: except json.JSONDecodeError as e: raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") + @classmethod + def _get_request_headers(cls) -> dict[str, str]: + """Prepare common headers for requests.""" + headers = {"User-Agent": "convert_hf_to_gguf"} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + return headers + @classmethod def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: """ - Get raw byte data from a remote file by range. - If size is not specified, it will read the entire file. - """ - import requests - from urllib.parse import urlparse + Get raw byte data from a remote file by range using single or multi-threaded download. + If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only). + If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads. + Otherwise, it uses a single request. + """ parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid URL: {url}") - headers = {} - if os.environ.get("HF_TOKEN"): - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" - if size > -1: - headers["Range"] = f"bytes={start}-{start + size}" - response = requests.get(url, allow_redirects=True, headers=headers) - response.raise_for_status() - - # Get raw byte data - return response.content[:size] + common_headers = cls._get_request_headers() + + # --- Multithreading Path --- + if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1: + # print(f"Using {cls.MULTITHREAD_COUNT} threads for size {size / (1024*1024):.2f} MB") + num_threads = cls.MULTITHREAD_COUNT + results: list[Any] = [None] * num_threads # Store results or exceptions + threads: list[threading.Thread] = [] + + def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict): + """Worker function for thread.""" + thread_headers = headers.copy() + # Range header is inclusive end byte + range_end = chunk_start + chunk_size - 1 + thread_headers["Range"] = f"bytes={chunk_start}-{range_end}" + try: + # Using stream=False should make requests wait for content download + response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120) # Added timeout + response.raise_for_status() # Check for HTTP errors + + content = response.content + if len(content) != chunk_size: + # This is a critical check + raise IOError( + f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. " + f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}" + ) + result_list[index] = content + except Exception as e: + # Store exception to be raised by the main thread + # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}") # Optional debug print + result_list[index] = e + + # Calculate chunk sizes and create/start threads + base_chunk_size = size // num_threads + remainder = size % num_threads + current_offset = start + + for i in range(num_threads): + chunk_size = base_chunk_size + (1 if i < remainder else 0) + if chunk_size == 0: # Should not happen if size >= threshold but handle defensively + results[i] = b"" # Store empty bytes for this "chunk" + continue + + thread = threading.Thread( + target=download_chunk, + args=(url, current_offset, chunk_size, i, results, common_headers), + daemon=True # Allow main thread to exit even if daemon threads are stuck (though join prevents this) + ) + threads.append(thread) + thread.start() + current_offset += chunk_size # Move offset for the next chunk + + # Wait for all threads to complete + for i, thread in enumerate(threads): + thread.join() # Wait indefinitely for each thread + + # Check results for errors and concatenate chunks + final_data_parts = [] + for i in range(num_threads): + result = results[i] + if isinstance(result, Exception): + # Raise the first exception encountered + raise result + elif result is None: + # This indicates a thread finished without setting its result or exception (unexpected) + # Check if it was supposed to download anything + expected_chunk_size = base_chunk_size + (1 if i < remainder else 0) + if expected_chunk_size > 0: + raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.") + else: + final_data_parts.append(b"") # Append empty bytes for zero-size chunk + else: + final_data_parts.append(result) + + # Combine the byte chunks + final_data = b"".join(final_data_parts) + + # Final validation: Does the combined size match the requested size? + if len(final_data) != size: + raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}") + + return final_data + + # --- Single-threaded Path --- + else: + # print(f"Using single thread for size {size}") # Optional debug print + headers = common_headers.copy() + if size > -1: + # Range header uses inclusive end byte + range_end = start + size - 1 + headers["Range"] = f"bytes={start}-{range_end}" + elif start > 0: + # Request from start offset to the end of the file + headers["Range"] = f"bytes={start}-" + # If start=0 and size=-1, no Range header is needed (get full file) + + response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120) # Added timeout + response.raise_for_status() + content = response.content + + # Validate downloaded size if a specific size was requested + if size > -1 and len(content) != size: + # Check status code - 206 Partial Content is expected for successful range requests + status_code = response.status_code + content_range = response.headers.get('Content-Range') + raise IOError( + f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), " + f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}" + ) + + return content @classmethod def check_file_exist(cls, url: str) -> bool: @@ -241,17 +358,13 @@ def check_file_exist(cls, url: str) -> bool: Check if a file exists at the given URL. Returns True if the file exists, False otherwise. """ - import requests - from urllib.parse import urlparse - parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid URL: {url}") try: - headers = {"Range": "bytes=0-0"} - if os.environ.get("HF_TOKEN"): - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + headers = cls._get_request_headers() + headers["Range"] = "bytes=0-0" # Request a small range to check existence response = requests.head(url, allow_redirects=True, headers=headers) # Success (2xx) or redirect (3xx) return 200 <= response.status_code < 400 From 63f0604a18065f49d2b7b158788faf82ce8e40fe Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 18:07:43 +0200 Subject: [PATCH 11/14] multithread: print debug --- gguf-py/gguf/utility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index c42b9cb1505a1..1b08d73d71a8f 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -243,7 +243,7 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: # --- Multithreading Path --- if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1: - # print(f"Using {cls.MULTITHREAD_COUNT} threads for size {size / (1024*1024):.2f} MB") + print(f"Using {cls.MULTITHREAD_COUNT} threads to download range of {size / (1024*1024):.2f} MB") num_threads = cls.MULTITHREAD_COUNT results: list[Any] = [None] * num_threads # Store results or exceptions threads: list[threading.Thread] = [] From c8760ccb1a1fd25a2fc1e8bfc650f4f469898a51 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 8 Apr 2025 18:28:13 +0200 Subject: [PATCH 12/14] fix style --- gguf-py/gguf/utility.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 1b08d73d71a8f..620030df15e7c 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -7,6 +7,7 @@ import json import requests import threading +import logging from urllib.parse import urlparse @@ -110,6 +111,8 @@ class SafetensorRemote: print(data) """ + logger = logging.getLogger("safetensor_remote") + BASE_DOMAIN = "https://huggingface.co" ALIGNMENT = 8 # bytes @@ -243,7 +246,7 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: # --- Multithreading Path --- if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1: - print(f"Using {cls.MULTITHREAD_COUNT} threads to download range of {size / (1024*1024):.2f} MB") + cls.logger.info(f"Using {cls.MULTITHREAD_COUNT} threads to download range of {size / (1024*1024):.2f} MB") num_threads = cls.MULTITHREAD_COUNT results: list[Any] = [None] * num_threads # Store results or exceptions threads: list[threading.Thread] = [] @@ -308,9 +311,9 @@ def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int # Check if it was supposed to download anything expected_chunk_size = base_chunk_size + (1 if i < remainder else 0) if expected_chunk_size > 0: - raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.") + raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.") else: - final_data_parts.append(b"") # Append empty bytes for zero-size chunk + final_data_parts.append(b"") # Append empty bytes for zero-size chunk else: final_data_parts.append(result) @@ -319,7 +322,7 @@ def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int # Final validation: Does the combined size match the requested size? if len(final_data) != size: - raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}") + raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}") return final_data From 2e535f6fbaad54ca2200d0161f69c0111df6de81 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Apr 2025 14:31:51 +0200 Subject: [PATCH 13/14] Revert "multithreaded download" This reverts commit 42fc895ace385edc972ad819c76c704aeea61791. --- gguf-py/gguf/utility.py | 160 ++++++---------------------------------- 1 file changed, 22 insertions(+), 138 deletions(-) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 620030df15e7c..89a9d494db0ee 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -1,14 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, Any +from typing import Literal import os import json -import requests -import threading -import logging -from urllib.parse import urlparse def fill_templated_filename(filename: str, output_type: str | None) -> str: @@ -111,15 +107,9 @@ class SafetensorRemote: print(data) """ - logger = logging.getLogger("safetensor_remote") - BASE_DOMAIN = "https://huggingface.co" ALIGNMENT = 8 # bytes - # start using multithread download for files larger than 100MB - MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024 - MULTITHREAD_COUNT = 8 # number of threads - @classmethod def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: """ @@ -221,139 +211,29 @@ def get_metadata(cls, url: str) -> tuple[dict, int]: except json.JSONDecodeError as e: raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") - @classmethod - def _get_request_headers(cls) -> dict[str, str]: - """Prepare common headers for requests.""" - headers = {"User-Agent": "convert_hf_to_gguf"} - if os.environ.get("HF_TOKEN"): - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" - return headers - @classmethod def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: """ - Get raw byte data from a remote file by range using single or multi-threaded download. - - If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only). - If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads. - Otherwise, it uses a single request. + Get raw byte data from a remote file by range. + If size is not specified, it will read the entire file. """ + import requests + from urllib.parse import urlparse + parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid URL: {url}") - common_headers = cls._get_request_headers() - - # --- Multithreading Path --- - if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1: - cls.logger.info(f"Using {cls.MULTITHREAD_COUNT} threads to download range of {size / (1024*1024):.2f} MB") - num_threads = cls.MULTITHREAD_COUNT - results: list[Any] = [None] * num_threads # Store results or exceptions - threads: list[threading.Thread] = [] - - def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict): - """Worker function for thread.""" - thread_headers = headers.copy() - # Range header is inclusive end byte - range_end = chunk_start + chunk_size - 1 - thread_headers["Range"] = f"bytes={chunk_start}-{range_end}" - try: - # Using stream=False should make requests wait for content download - response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120) # Added timeout - response.raise_for_status() # Check for HTTP errors - - content = response.content - if len(content) != chunk_size: - # This is a critical check - raise IOError( - f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. " - f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}" - ) - result_list[index] = content - except Exception as e: - # Store exception to be raised by the main thread - # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}") # Optional debug print - result_list[index] = e - - # Calculate chunk sizes and create/start threads - base_chunk_size = size // num_threads - remainder = size % num_threads - current_offset = start - - for i in range(num_threads): - chunk_size = base_chunk_size + (1 if i < remainder else 0) - if chunk_size == 0: # Should not happen if size >= threshold but handle defensively - results[i] = b"" # Store empty bytes for this "chunk" - continue - - thread = threading.Thread( - target=download_chunk, - args=(url, current_offset, chunk_size, i, results, common_headers), - daemon=True # Allow main thread to exit even if daemon threads are stuck (though join prevents this) - ) - threads.append(thread) - thread.start() - current_offset += chunk_size # Move offset for the next chunk - - # Wait for all threads to complete - for i, thread in enumerate(threads): - thread.join() # Wait indefinitely for each thread - - # Check results for errors and concatenate chunks - final_data_parts = [] - for i in range(num_threads): - result = results[i] - if isinstance(result, Exception): - # Raise the first exception encountered - raise result - elif result is None: - # This indicates a thread finished without setting its result or exception (unexpected) - # Check if it was supposed to download anything - expected_chunk_size = base_chunk_size + (1 if i < remainder else 0) - if expected_chunk_size > 0: - raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.") - else: - final_data_parts.append(b"") # Append empty bytes for zero-size chunk - else: - final_data_parts.append(result) - - # Combine the byte chunks - final_data = b"".join(final_data_parts) - - # Final validation: Does the combined size match the requested size? - if len(final_data) != size: - raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}") - - return final_data - - # --- Single-threaded Path --- - else: - # print(f"Using single thread for size {size}") # Optional debug print - headers = common_headers.copy() - if size > -1: - # Range header uses inclusive end byte - range_end = start + size - 1 - headers["Range"] = f"bytes={start}-{range_end}" - elif start > 0: - # Request from start offset to the end of the file - headers["Range"] = f"bytes={start}-" - # If start=0 and size=-1, no Range header is needed (get full file) - - response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120) # Added timeout - response.raise_for_status() - content = response.content - - # Validate downloaded size if a specific size was requested - if size > -1 and len(content) != size: - # Check status code - 206 Partial Content is expected for successful range requests - status_code = response.status_code - content_range = response.headers.get('Content-Range') - raise IOError( - f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), " - f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}" - ) - - return content + headers = {} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + if size > -1: + headers["Range"] = f"bytes={start}-{start + size}" + response = requests.get(url, allow_redirects=True, headers=headers) + response.raise_for_status() + + # Get raw byte data + return response.content[:size] @classmethod def check_file_exist(cls, url: str) -> bool: @@ -361,13 +241,17 @@ def check_file_exist(cls, url: str) -> bool: Check if a file exists at the given URL. Returns True if the file exists, False otherwise. """ + import requests + from urllib.parse import urlparse + parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid URL: {url}") try: - headers = cls._get_request_headers() - headers["Range"] = "bytes=0-0" # Request a small range to check existence + headers = {"Range": "bytes=0-0"} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" response = requests.head(url, allow_redirects=True, headers=headers) # Success (2xx) or redirect (3xx) return 200 <= response.status_code < 400 From e8b7d263ab11a275252eb8c651987772d9923077 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Apr 2025 17:36:24 +0200 Subject: [PATCH 14/14] bring back _get_request_headers --- gguf-py/gguf/utility.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 89a9d494db0ee..e5251aef8c832 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -224,9 +224,7 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid URL: {url}") - headers = {} - if os.environ.get("HF_TOKEN"): - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + headers = cls._get_request_headers() if size > -1: headers["Range"] = f"bytes={start}-{start + size}" response = requests.get(url, allow_redirects=True, headers=headers) @@ -249,11 +247,18 @@ def check_file_exist(cls, url: str) -> bool: raise ValueError(f"Invalid URL: {url}") try: - headers = {"Range": "bytes=0-0"} - if os.environ.get("HF_TOKEN"): - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + headers = cls._get_request_headers() + headers["Range"] = "bytes=0-0" response = requests.head(url, allow_redirects=True, headers=headers) # Success (2xx) or redirect (3xx) return 200 <= response.status_code < 400 except requests.RequestException: return False + + @classmethod + def _get_request_headers(cls) -> dict[str, str]: + """Prepare common headers for requests.""" + headers = {"User-Agent": "convert_hf_to_gguf"} + if os.environ.get("HF_TOKEN"): + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" + return headers