diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index a1cbe886492..b9243cd99d8 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -395,7 +395,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): padded_weight = torch.cat([ loaded_weight, torch.zeros(param.shape[0] - loaded_weight.shape[0], - *loaded_weight.shape[1:]) + *loaded_weight.shape[1:], + device=loaded_weight.device) ]) else: padded_weight = loaded_weight