@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
+from typing import Optional

import torch
import torch.nn as nn
@@ -22,7+23,6 @@
from vllm.model_executor.models.utils import extract_layer_index

from .utils import AutoWeightsLoader, maybe_prefix
-from typing import Optional

logger = init_logger(__name__)
@@ -39,8 +39,8 @@ def __init__(
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
-        self.config = vllm_config. \
-            speculative_config.draft_model_config.hf_config
+        self.config = (
+            vllm_config.speculative_config.draft_model_config.hf_config)
        self.validate_and_update_config(start_layer_id, quant_config)
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
@@ -153,8 +153,8 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
-        self.config = vllm_config. \
-            speculative_config.draft_model_config.hf_config
+        self.config = (
+            vllm_config.speculative_config.draft_model_config.hf_config)
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        # draft model quantization config may differ from target model
@@ -177,7 +177,10 @@ def forward(
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+    def load_weights(
+        self,
+        weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> None:
        loader = AutoWeightsLoader(
            self,
            # lm_head is tied with target model (Llama4ForCausalLM)
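
The last hunk is cut off inside the AutoWeightsLoader(...) call, so for context, here is a minimal sketch of how a load_weights body built on AutoWeightsLoader is typically completed. The skip_prefixes argument and the _DraftModelSketch wrapper are assumptions for illustration and are not shown in the diff above.

from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class _DraftModelSketch(nn.Module):
    """Hypothetical stand-in for EagleLlama4ForCausalLM; illustration only."""

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]]
    ) -> None:
        loader = AutoWeightsLoader(
            self,
            # lm_head is tied with the target model (Llama4ForCausalLM),
            # so the draft model skips loading its own copy (assumed here
            # via skip_prefixes; the hunk ends before this argument).
            skip_prefixes=["lm_head."],
        )
        loader.load_weights(weights)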