from collections.abc import AsyncGenerator
from typing import Union

import torch

from vllm.beam.debug import BeamDebugInfo
from vllm.beam.penalty import PenaltyComputer
from vllm.beam.ranking import RankingComputer
from vllm.entrypoints.openai.protocol import CompletionResponse, ErrorResponse
from vllm.logger import init_logger

logger = init_logger(__name__)
class BeamScorer:
    """Scores candidate beam responses and selects the single best one.

    Two signals are derived from each response's additional classifier
    heads: a penalty (subtracted from the score) and a ranking score
    (added to it). The response with the highest combined score wins.
    """

    def __init__(self, classi_idx):
        # Both computers share the same classifier-index mapping.
        self.penalty_computer = PenaltyComputer(classi_idx)
        self.ranking_computer = RankingComputer(classi_idx)

    async def pick_best_beam(
        self,
        responses: list[
            Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]
        ],
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Return the highest-scoring response among ``responses``.

        Args:
            responses: candidate beam completions to choose between.

        Returns:
            The element of ``responses`` with the maximal combined score.

        Raises:
            ValueError: if ``responses`` is empty — there is nothing to pick.
        """
        if not responses:
            # torch.argmax on an empty tensor raises an opaque RuntimeError;
            # fail early with a clear message instead.
            raise ValueError("pick_best_beam requires at least one response")

        debug_info = [BeamDebugInfo() for _ in responses]

        scores = torch.zeros(len(responses), dtype=torch.float)

        # First additional head of the first choice of every response.
        # NOTE(review): assumes every element exposes
        # choices[0].additional_heads — confirm ErrorResponse / generator
        # entries are filtered out by the caller before reaching here.
        heads = [response.choices[0].additional_heads[0] for response in responses]
        heads_tensor = torch.tensor(heads, dtype=torch.float)
        if len(heads_tensor) > 0:
            penalties = self.penalty_computer.compute(heads_tensor, debug_info)
            scores -= penalties

            # NOTE(review): ranking is only applied when heads exist, matching
            # the penalty branch — confirm this is the intended scoping.
            ranking_scores = self.ranking_computer.compute(
                heads_tensor, debug_info
            )
            scores += ranking_scores

        # Record per-candidate diagnostics for the debug log below.
        for i, response in enumerate(responses):
            debug_info[i].final_score = scores[i]
            debug_info[i].content = response.choices[0].text

        logger.debug('debug_info: %s', debug_info)

        best_idx = torch.argmax(scores).item()
        return responses[best_idx]