|
4 | 4 | import struct
|
5 | 5 | import time
|
6 | 6 | from dataclasses import dataclass
|
7 |
| -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union |
| 7 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union |
8 | 8 |
|
9 | 9 | import requests
|
10 | 10 | import torch
|
@@ -497,7 +497,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
|
497 | 497 | servers = self.cluster_info.get_servers_by_role(ServerRole.Prefill)
|
498 | 498 | assert len(servers) == 1, \
|
499 | 499 | f"Expected only one server for {self.kv_role}, but got {len(servers)}"
|
500 |
| - prefill_infos = { |
| 500 | + prefill_infos: Dict[str, Any] = { |
501 | 501 | request_id: {
|
502 | 502 | "dp_rank": 0,
|
503 | 503 | "server_id": servers[0].server_id,
|
@@ -548,8 +548,8 @@ def start_load_kv(self, forward_context: "ForwardContext",
|
548 | 548 | kv_cache_shape = (1, 2, slen, num_heads, head_dim)
|
549 | 549 |
|
550 | 550 | uniq_req_id = self._get_unique_req_id(request.request_id)
|
551 |
| - dp_rank = prefill_infos[uniq_req_id]["dp_rank"] |
552 |
| - server_id = prefill_infos[uniq_req_id]["server_id"] |
| 551 | + dp_rank: int = prefill_infos[uniq_req_id]["dp_rank"] |
| 552 | + server_id: str = prefill_infos[uniq_req_id]["server_id"] |
553 | 553 |
|
554 | 554 | # pull kv cache from prefill node by request
|
555 | 555 | kv_hidden_dtype = kv_cache_layers[0].dtype
|
|
0 commit comments