@@ -1,10 +1,14 @@
 import asyncio
 from collections.abc import AsyncGenerator
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from urllib.request import Request
 from vllm.beam.debug import BeamDebugInfo
 import torch
-from vllm.entrypoints.openai.protocol import CompletionRequest
+from vllm.entrypoints.openai.protocol import CompletionRequest, CompletionResponse, \
+    ErrorResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)



 def format_filter(filter_params_list):
@@ -13,65 +17,75 @@ def format_filter(filter_params_list):
         raise ValueError(f"Expect 2 keys, seeing {filter_params_list=}")
     return [dict(zip(keys, vals)) for vals in filter_params_list]

+
 DEFAULT_CHAR_SERVER_FILTER = format_filter(
-    [
-        ("annotations_porn", 0.1098),
-        ("annotations_racist", 0.2814),
-        ("annotations_disturbing", 0.1827),
-        ("annotations_harmful_promotes_selfharm", 0.0749),
-        ("annotations_harmful_promotes_terrorism", 0.1129),
-    ]
-)
-
+    [
+        ("annotations_porn", 0.1098),
+        ("annotations_racist", 0.2814),
+        ("annotations_disturbing", 0.1827),
+        ("annotations_harmful_promotes_selfharm", 0.0749),
+        ("annotations_harmful_promotes_terrorism", 0.1129),
+    ]
+)
+
 MAX_GENERATIONS = 10
 _CHUNK_SIZE = 16

+
 class BeamValidator:
     def __init__(self, classi_idx, classifier_names):
         self.classi_idx = classi_idx
         self.classifier_names = classifier_names

-    async def get_n_valid_beams(self, create_completion: Callable, request: CompletionRequest, raw_request: Optional[Request] = None):
+    async def get_n_valid_beams(self, create_completion: Callable,
+                                request: CompletionRequest,
+                                raw_request: Optional[Request] = None) -> list[
+        Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]]:
         request.stream = False
         n = request.n
         request.n = 1
+        # TODO(@tanuj): accept max tokens as a parameter
         request.max_tokens = _CHUNK_SIZE
         request.echo = True
         tasks = []
+        # TODO(@tanuj): deep copy request and raw_request?
         for _ in range(n):
             request = request
             tasks.append(create_completion(
                 request,
+                raw_request=raw_request,
             ))
         res = await asyncio.gather(*tasks)
         request.n = n
         beam_validator_res = self.validate(res)
         filtered_res = [r for r, valid in zip(res, beam_validator_res) if valid]
-        print('everything is filtered', len(filtered_res) == 0)
+        logger.debug("Filtered count: %d", len(filtered_res))
         if len(filtered_res) == 0:
             return res
-
+
         return filtered_res
-
-    def validate(self, responses: list[AsyncGenerator], debug_infos_G: list[BeamDebugInfo] = None):
-        #TODO(@tanuj) - share this with the beam scorer
+
+    def validate(self, responses: list[AsyncGenerator],
+                 debug_infos_G: list[BeamDebugInfo] = None):
+        # TODO(@tanuj) - share this with the beam scorer
         heads = [response.choices[0].additional_heads[0] for response in responses]
         heads_tensor = torch.tensor(heads, dtype=torch.float)
         prob_GC = torch.sigmoid(heads_tensor)
         valid_G = torch.ones(prob_GC.shape[0], dtype=torch.bool)
-
+
         for g in range(heads_tensor.shape[0]):
-            filtered = self.get_filtered_classifiers(prob_GC[g], DEFAULT_CHAR_SERVER_FILTER)
+            filtered = self.get_filtered_classifiers(prob_GC[g],
+                                                     DEFAULT_CHAR_SERVER_FILTER)
             if debug_infos_G is not None:
                 debug_infos_G[g].filtered_classifiers = filtered
-
+
             if filtered:
                 valid_G[g] = False
                 for choice in responses[g].choices:
                     choice.is_filtered = True

         return valid_G
-
+
     def get_filtered_classifiers(self, prob_C, filter_params) -> list[str]:
         relevant_filters = [
             (p["name"], self.classi_idx[p["name"]], p["threshold"])
@@ -87,4 +101,4 @@ def get_filtered_classifiers(self, prob_C, filter_params) -> list[str]:
             if prob_C[idx] > threshold:
                 ret.append(name)

-        return ret
+        return ret
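
For context, here is a minimal, self-contained sketch of the threshold check that validate and get_filtered_classifiers perform in the diff above. The classifier names, index mapping, and logit values below are made up for illustration; the snippet only mirrors the logic shown in the diff rather than importing BeamValidator itself.

import torch

# Hypothetical name -> column-index mapping and (name, threshold) filters,
# shaped like classi_idx and DEFAULT_CHAR_SERVER_FILTER in the diff.
classi_idx = {"annotations_porn": 0, "annotations_racist": 1}
filter_params = [
    {"name": "annotations_porn", "threshold": 0.1098},
    {"name": "annotations_racist", "threshold": 0.2814},
]

# Made-up additional-head logits for one generation; sigmoid converts them
# to per-classifier probabilities, as validate() does for each beam.
prob_C = torch.sigmoid(torch.tensor([-1.5, -2.0]))

# A classifier trips the filter when its probability exceeds its threshold,
# mirroring get_filtered_classifiers().
filtered = [
    p["name"]
    for p in filter_params
    if prob_C[classi_idx[p["name"]]] > p["threshold"]
]
print(filtered)  # ['annotations_porn'] (sigmoid(-1.5) ~ 0.18 > 0.1098; sigmoid(-2.0) ~ 0.12 < 0.2814)

A beam whose list comes back non-empty is then marked invalid (valid_G[g] = False) and each of its choices is flagged is_filtered, as shown in the diff.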