Skip to content

Commit e9fecc1

Browse files
rahul-tuliclaude
andcommitted
fix: Support HuggingFace model IDs in speculators Eagle config
- Use PretrainedConfig.get_config_dict() to handle both local and HF paths - Simplifies the code and follows best practices - Tested with both local paths and HuggingFace model IDs 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 58824b0 commit e9fecc1

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

vllm/transformers_utils/configs/speculators_eagle.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,12 @@ def from_pretrained(
2828
"""
2929
Load a speculators Eagle config and convert it to vLLM format.
3030
"""
31-
config_path = Path(pretrained_model_name_or_path) / "config.json"
32-
33-
if not config_path.exists():
34-
# Fall back to standard loading if not a local path
35-
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
36-
37-
with open(config_path, "r") as f:
38-
config_dict = json.load(f)
31+
# Use the parent class method to load config dict
32+
# This handles both local paths and HuggingFace model IDs
33+
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
3934

4035
# Check if this is a speculators format config
41-
if "speculators_model_type" not in config_dict:
36+
if config_dict.get("speculators_model_type") != "eagle":
4237
# Not a speculators config, use standard loading
4338
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
4439

@@ -111,13 +106,9 @@ def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
111106
"""
112107
Check if a config file is in speculators Eagle format.
113108
"""
114-
config_file = Path(config_path) / "config.json"
115-
if not config_file.exists():
116-
return False
117-
118109
try:
119-
with open(config_file, "r") as f:
120-
config = json.load(f)
121-
return config.get("speculators_model_type") == "eagle"
110+
# Use PretrainedConfig to load from both local and HF paths
111+
config_dict, _ = PretrainedConfig.get_config_dict(config_path)
112+
return config_dict.get("speculators_model_type") == "eagle"
122113
except:
123114
return False

0 commit comments

Comments
 (0)