Skip to content

Commit a24cb91

Browse files
authored
[Model] Fix minimax model cache & lm_head precision (#19592)
Signed-off-by: qingjun <qingjun@minimaxi.com>
1 parent 7e8d97d commit a24cb91

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

vllm/model_executor/models/minimax_text_01.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -856,7 +856,7 @@ def layer_fn(prefix):
856856
self._dtype = _dummy.dtype
857857
del _dummy
858858

859-
self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
859+
self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
860860
cache_shape=self.cache_shape)
861861

862862
rope_theta = getattr(config, "rope_theta", 10000)
@@ -1021,7 +1021,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
10211021

10221022
else:
10231023
self.lm_head = PPMissingLayer()
1024-
1024+
self.lm_head.float()
10251025
flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
10261026
if attn_type == 1)
10271027
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
@@ -1054,7 +1054,7 @@ def forward(self,
10541054

10551055
def compute_logits(self, hidden_states: torch.Tensor,
10561056
sampling_metadata: SamplingMetadata) -> torch.Tensor:
1057-
logits = self.logits_processor(self.lm_head, hidden_states,
1057+
logits = self.logits_processor(self.lm_head, hidden_states.float(),
10581058
sampling_metadata)
10591059

10601060
return logits

0 commit comments

Comments (0)