 import functools
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from enum import Enum
-from typing import Dict, Iterator, List, Optional, Type
+from typing import AsyncGenerator, Dict, List, Literal, Optional, Type

 from loguru import logger
+from pydantic import BaseModel

 from guidellm.core import TextGenerationRequest, TextGenerationResult

 __all__ = ["Backend", "BackendEngine", "GenerativeResponse"]


-class BackendEngine(str, Enum):
-    """
-    Determines the Engine of the LLM Backend.
-    All the implemented backends in the project have the engine.
-
-    NOTE: the `TEST` engine has to be used only for testing purposes.
-    """
+BackendEngine = Literal["test", "openai_server"]

-    TEST = "test"
-    OPENAI_SERVER = "openai_server"

-
-@dataclass
-class GenerativeResponse:
+class GenerativeResponse(BaseModel):
     """
-    A dataclass to represent a response from a generative AI backend.
+    A model representing a response from a generative AI backend.
+
+    :param type_: The type of response, either 'token_iter' for intermediate
+        token output or 'final' for the final result.
+    :type type_: Literal["token_iter", "final"]
+    :param add_token: The token to add to the output
+        (only applicable if type_ is 'token_iter').
+    :type add_token: Optional[str]
+    :param prompt: The original prompt sent to the backend.
+    :type prompt: Optional[str]
+    :param output: The final generated output (only applicable if type_ is 'final').
+    :type output: Optional[str]
+    :param prompt_token_count: The number of tokens in the prompt.
+    :type prompt_token_count: Optional[int]
+    :param output_token_count: The number of tokens in the output.
+    :type output_token_count: Optional[int]
     """

-    type_: str  # One of 'token_iter', 'final'
+    type_: Literal["token_iter", "final"]
     add_token: Optional[str] = None
     prompt: Optional[str] = None
     output: Optional[str] = None
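
Taken together, the two `type_` values define a small streaming protocol: zero or more "token_iter" responses carrying `add_token`, followed by exactly one "final" response carrying the full `output` and the token counts (the reworked `submit` later in this diff enforces this). A minimal sketch of such a stream, with invented tokens and counts:

    stream = [
        GenerativeResponse(type_="token_iter", add_token="Hi"),
        GenerativeResponse(type_="token_iter", add_token="!"),
        GenerativeResponse(
            type_="final", output="Hi!", prompt_token_count=4, output_token_count=2
        ),
    ]

Because the class now derives from pydantic's BaseModel, a `type_` outside the Literal (e.g. "bogus") raises a ValidationError at construction time instead of slipping through as a plain string.
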
@@ -39,7 +43,14 @@ class GenerativeResponse:

 class Backend(ABC):
     """
-    An abstract base class with template methods for generative AI backends.
+    Abstract base class for generative AI backends.
+
+    This class provides a common interface for creating and interacting with different
+    generative AI backends. Subclasses should implement the abstract methods to
+    define specific backend behavior.
+
+    :cvar _registry: A dictionary that maps BackendEngine types to backend classes.
+    :type _registry: Dict[BackendEngine, Type[Backend]]
     """

     _registry: Dict[BackendEngine, "Type[Backend]"] = {}
@@ -50,33 +61,38 @@ def register(cls, backend_type: BackendEngine):
         A decorator to register a backend class in the backend registry.

         :param backend_type: The type of backend to register.
-        :type backend_type: BackendType
+        :type backend_type: BackendEngine
+        :return: The decorated backend class.
+        :rtype: Type[Backend]
         """

         def inner_wrapper(wrapped_class: Type["Backend"]):
             cls._registry[backend_type] = wrapped_class
+            logger.info("Registered backend type: {}", backend_type)
             return wrapped_class

         return inner_wrapper

     @classmethod
     def create(cls, backend_type: BackendEngine, **kwargs) -> "Backend":
         """
-        Factory method to create a backend based on the backend type.
+        Factory method to create a backend instance based on the backend type.

         :param backend_type: The type of backend to create.
-        :type backend_type: BackendType
+        :type backend_type: BackendEngine
         :param kwargs: Additional arguments for backend initialization.
         :type kwargs: dict
         :return: An instance of a subclass of Backend.
         :rtype: Backend
+        :raises ValueError: If the backend type is not registered.
         """

-        logger.info(f"Creating backend of type {backend_type}")
+        logger.info("Creating backend of type {}", backend_type)

         if backend_type not in cls._registry:
-            logger.error(f"Unsupported backend type: {backend_type}")
-            raise ValueError(f"Unsupported backend type: {backend_type}")
+            err = ValueError(f"Unsupported backend type: {backend_type}")
+            logger.error("{}", err)
+            raise err

         return Backend._registry[backend_type](**kwargs)

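As a usage sketch, the decorator and the factory work as a pair; `DummyBackend` here is a hypothetical stand-in, and the example assumes no real backend module has been imported (and thus registered) in the session:

    @Backend.register(backend_type="test")
    class DummyBackend(Backend):
        ...  # abstract methods elided; a concrete subclass must implement them

    # The decorator stored the class under its engine name.
    assert Backend._registry["test"] is DummyBackend

    # Engine names without a registered class fail fast in create().
    try:
        Backend.create(backend_type="openai_server")
    except ValueError as err:
        print(err)  # Unsupported backend type: openai_server

Registering via a decorator means each backend module adds itself to the registry on import, leaving `create()` a single dictionary lookup plus instantiation.
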
@@ -87,82 +103,119 @@ def default_model(self) -> str:

         :return: The default model.
         :rtype: str
+        :raises ValueError: If no models are available.
         """
         return _cachable_default_model(self)

-    def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
+    async def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
         """
-        Submit a result request and populate the BenchmarkResult.
+        Submit a text generation request and return the result.

-        :param request: The result request to submit.
+        This method handles the request submission to the backend and processes
+        the response in a streaming fashion if applicable.
+
+        :param request: The request object containing the prompt
+            and other configurations.
         :type request: TextGenerationRequest
-        :return: The populated result result.
+        :return: The result of the text generation request.
         :rtype: TextGenerationResult
+        :raises ValueError: If no response is received from the backend.
         """

-        logger.info(f"Submitting request with prompt: {request.prompt}")
+        logger.debug("Submitting request with prompt: {}", request.prompt)

-        result = TextGenerationResult(
-            request=TextGenerationRequest(prompt=request.prompt),
-        )
+        result = TextGenerationResult(request=request)
         result.start(request.prompt)
+        received_final = False

-        for response in self.make_request(request):  # GenerativeResponse
-            if response.type_ == "token_iter" and response.add_token:
-                result.output_token(response.add_token)
+        async for response in self.make_request(request):
+            logger.debug("Received response: {}", response)
+            if response.type_ == "token_iter":
+                result.output_token(response.add_token if response.add_token else "")
             elif response.type_ == "final":
+                if received_final:
+                    err = ValueError(
+                        "Received multiple final responses from the backend."
+                    )
+                    logger.error(err)
+                    raise err
+
                 result.end(
+                    output=response.output,
                     prompt_token_count=response.prompt_token_count,
                     output_token_count=response.output_token_count,
                 )
+                received_final = True
+            else:
+                err = ValueError(
+                    f"Invalid response received from the backend of type: "
+                    f"{response.type_} for {response}"
+                )
+                logger.error(err)
+                raise err

-        logger.info(f"Request completed with output: {result.output}")
+        if not received_final:
+            err = ValueError("No final response received from the backend.")
+            logger.error(err)
+            raise err
+
+        logger.info("Request completed with output: {}", result.output)

         return result

     @abstractmethod
-    def make_request(
+    async def make_request(
         self,
         request: TextGenerationRequest,
-    ) -> Iterator[GenerativeResponse]:
+    ) -> AsyncGenerator[GenerativeResponse, None]:
         """
         Abstract method to make a request to the backend.

-        :param request: The result request to submit.
+        Subclasses must implement this method to define how requests are handled
+        by the backend.
+
+        :param request: The request object containing the prompt and
+            other configurations.
         :type request: TextGenerationRequest
-        :return: An iterator over the generative responses.
-        :rtype: Iterator[GenerativeResponse]
+        :yield: A generator yielding responses from the backend.
+        :rtype: AsyncGenerator[GenerativeResponse, None]
         """
-        raise NotImplementedError
+        yield None  # type: ignore  # noqa: PGH003

     @abstractmethod
     def available_models(self) -> List[str]:
         """
         Abstract method to get the available models for the backend.

+        Subclasses must implement this method to provide the list of models
+        supported by the backend.
+
         :return: A list of available models.
         :rtype: List[str]
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def model_tokenizer(self, model: str) -> Optional[str]:
-        """
-        Abstract method to get the tokenizer for a model.
-
-        :param model: The model to get the tokenizer for.
-        :type model: str
-        :return: The tokenizer for the model, or None if it cannot be created.
-        :rtype: Optional[str]
+        :raises NotImplementedError: If the method is not implemented by a subclass.
         """
         raise NotImplementedError


 @functools.lru_cache(maxsize=1)
 def _cachable_default_model(backend: Backend) -> str:
-    if models := backend.available_models():
-        logger.debug(f"Default model: {models[0]}")
+    """
+    Get the default model for a backend using LRU caching.
+
+    This function caches the default model to optimize repeated lookups.
+
+    :param backend: The backend instance for which to get the default model.
+    :type backend: Backend
+    :return: The default model.
+    :rtype: str
+    :raises ValueError: If no models are available.
+    """
+    logger.debug("Getting default model for backend: {}", backend)
+    models = backend.available_models()
+    if models:
+        logger.debug("Default model: {}", models[0])
         return models[0]

-    logger.error("No models available.")
-    raise ValueError("No models available.")
+    err = ValueError("No models available.")
+    logger.error(err)
+    raise err
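
To make the new asynchronous contract concrete, here is a self-contained, hypothetical sketch of a minimal streaming backend driven through `submit`; the class, model name, prompt, and counts are all invented, and it assumes the `TextGenerationRequest(prompt=...)` constructor used elsewhere in this diff:

    import asyncio

    @Backend.register(backend_type="test")
    class StreamingStub(Backend):
        async def make_request(self, request):
            # Two intermediate tokens, then the single mandatory final response.
            for token in ("Hello", " world"):
                yield GenerativeResponse(type_="token_iter", add_token=token)
            yield GenerativeResponse(
                type_="final",
                output="Hello world",
                prompt_token_count=3,
                output_token_count=2,
            )

        def available_models(self):
            return ["stub-model"]

    async def main():
        backend = Backend.create(backend_type="test")
        result = await backend.submit(TextGenerationRequest(prompt="Say hello"))
        print(result.output)

    asyncio.run(main())

Note that `submit` now raises ValueError if the stream ends without a "final" response or yields more than one, so a conforming `make_request` must emit exactly one.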