 )
 from guidellm.config import settings

-__all__ = ["CHAT_COMPLETIONS_PATH", "TEXT_COMPLETIONS_PATH", "OpenAIHTTPBackend"]
+__all__ = [
+    "CHAT_COMPLETIONS",
+    "CHAT_COMPLETIONS_PATH",
+    "MODELS",
+    "TEXT_COMPLETIONS",
+    "TEXT_COMPLETIONS_PATH",
+    "OpenAIHTTPBackend",
+]


 TEXT_COMPLETIONS_PATH = "/v1/completions"
 CHAT_COMPLETIONS_PATH = "/v1/chat/completions"

+EndpointType = Literal["chat_completions", "models", "text_completions"]
+CHAT_COMPLETIONS: EndpointType = "chat_completions"
+MODELS: EndpointType = "models"
+TEXT_COMPLETIONS: EndpointType = "text_completions"
+

 @Backend.register("openai_http")
 class OpenAIHTTPBackend(Backend):
@@ -53,6 +65,11 @@ class OpenAIHTTPBackend(Backend):
         If not provided, the default value from settings is used.
     :param max_output_tokens: The maximum number of tokens to request for completions.
         If not provided, the default maximum tokens provided from settings is used.
+    :param extra_query: Query parameters to include in requests to the OpenAI server.
+        If "chat_completions", "models", or "text_completions" are included as keys,
+        the values of these keys will be used as the parameters for the respective
+        endpoint.
+        If not provided, no extra query parameters are added.
     """

     def __init__(
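Taken together with the docstring above, `extra_query` supports two shapes: a flat dict applied to every endpoint, or a dict keyed by endpoint type. A minimal usage sketch; the target URL and query parameter names here are hypothetical, not from this diff:

```python
# Flat dict: the same query parameters are sent to every endpoint.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",
    extra_query={"api-version": "2024-06-01"},
)

# Keyed by endpoint type: each endpoint gets only its own parameters.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",
    extra_query={
        "models": {"api-version": "2024-06-01"},
        "chat_completions": {"api-version": "2024-06-01"},
    },
)
```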
@@ -66,6 +83,7 @@ def __init__(
         http2: Optional[bool] = True,
         follow_redirects: Optional[bool] = None,
         max_output_tokens: Optional[int] = None,
+        extra_query: Optional[dict] = None,
     ):
         super().__init__(type_="openai_http")
         self._target = target or settings.openai.base_url
@@ -101,6 +119,7 @@ def __init__(
             if max_output_tokens is not None
             else settings.openai.max_output_tokens
         )
+        self.extra_query = extra_query
         self._async_client: Optional[httpx.AsyncClient] = None

     @property
@@ -174,7 +193,10 @@ async def available_models(self) -> list[str]:
         """
         target = f"{self.target}/v1/models"
         headers = self._headers()
-        response = await self._get_async_client().get(target, headers=headers)
+        params = self._params(MODELS)
+        response = await self._get_async_client().get(
+            target, headers=headers, params=params
+        )
         response.raise_for_status()

         models = []
@@ -219,6 +241,7 @@ async def text_completions(  # type: ignore[override]
         )

         headers = self._headers()
+        params = self._params(TEXT_COMPLETIONS)
         payload = self._completions_payload(
             orig_kwargs=kwargs,
             max_output_tokens=output_token_count,
@@ -232,14 +255,16 @@ async def text_completions(  # type: ignore[override]
                 request_prompt_tokens=prompt_token_count,
                 request_output_tokens=output_token_count,
                 headers=headers,
+                params=params,
                 payload=payload,
             ):
                 yield resp
         except Exception as ex:
             logger.error(
-                "{} request with headers: {} and payload: {} failed: {}",
+                "{} request with headers: {} and params: {} and payload: {} failed: {}",
                 self.__class__.__name__,
                 headers,
+                params,
                 payload,
                 ex,
             )
@@ -291,6 +316,7 @@ async def chat_completions(  # type: ignore[override]
         """
         logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
         headers = self._headers()
+        params = self._params(CHAT_COMPLETIONS)
         messages = (
             content if raw_content else self._create_chat_messages(content=content)
         )
@@ -307,14 +333,16 @@ async def chat_completions(  # type: ignore[override]
                 request_prompt_tokens=prompt_token_count,
                 request_output_tokens=output_token_count,
                 headers=headers,
+                params=params,
                 payload=payload,
             ):
                 yield resp
         except Exception as ex:
             logger.error(
-                "{} request with headers: {} and payload: {} failed: {}",
+                "{} request with headers: {} and params: {} and payload: {} failed: {}",
                 self.__class__.__name__,
                 headers,
+                params,
                 payload,
                 ex,
             )
@@ -355,6 +383,19 @@ def _headers(self) -> dict[str, str]:

         return headers

+    def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
+        if self.extra_query is None:
+            return {}
+
+        if (
+            CHAT_COMPLETIONS in self.extra_query
+            or MODELS in self.extra_query
+            or TEXT_COMPLETIONS in self.extra_query
+        ):
+            return self.extra_query.get(endpoint_type, {})
+
+        return self.extra_query
+
     def _completions_payload(
         self, orig_kwargs: Optional[dict], max_output_tokens: Optional[int], **kwargs
     ) -> dict:
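To make the `_params` dispatch above easy to verify, here is a standalone sketch of the same resolution rule with the class stripped away; purely illustrative, not part of the diff:

```python
from typing import Literal, Optional

EndpointType = Literal["chat_completions", "models", "text_completions"]
ENDPOINT_KEYS = ("chat_completions", "models", "text_completions")


def resolve_params(extra_query: Optional[dict], endpoint_type: EndpointType) -> dict:
    """Mirror of _params: if any endpoint-type key is present, the dict is
    treated as per-endpoint; otherwise the whole dict is shared."""
    if extra_query is None:
        return {}
    if any(key in extra_query for key in ENDPOINT_KEYS):
        # Keyed form: only the matching endpoint's parameters are used.
        return extra_query.get(endpoint_type, {})
    # Flat form: the same parameters apply to every endpoint.
    return extra_query


assert resolve_params(None, "models") == {}
assert resolve_params({"a": "1"}, "models") == {"a": "1"}
assert resolve_params({"models": {"a": "1"}}, "models") == {"a": "1"}
assert resolve_params({"models": {"a": "1"}}, "chat_completions") == {}
```

One consequence of this rule worth noting: in the keyed form, an endpoint without its own key gets no extra parameters at all, rather than falling back to a shared default.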
@@ -451,8 +492,9 @@ async def _iterative_completions_request(
         request_id: Optional[str],
         request_prompt_tokens: Optional[int],
         request_output_tokens: Optional[int],
-        headers: dict,
-        payload: dict,
+        headers: dict[str, str],
+        params: dict[str, str],
+        payload: dict[str, Any],
     ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
         if type_ == "text_completions":
             target = f"{self.target}{TEXT_COMPLETIONS_PATH}"
@@ -463,14 +505,16 @@ async def _iterative_completions_request(

         logger.info(
             "{} making request: {} to target: {} using http2: {} following "
-            "redirects: {} for timeout: {} with headers: {} and payload: {}",
+            "redirects: {} for timeout: {} with headers: {} and params: {} and "
+            "payload: {}",
             self.__class__.__name__,
             request_id,
             target,
             self.http2,
             self.follow_redirects,
             self.timeout,
             headers,
+            params,
             payload,
         )

@@ -498,7 +542,7 @@ async def _iterative_completions_request(
         start_time = time.time()

         async with self._get_async_client().stream(
-            "POST", target, headers=headers, json=payload
+            "POST", target, headers=headers, params=params, json=payload
         ) as stream:
             stream.raise_for_status()

@@ -542,10 +586,12 @@ async def _iterative_completions_request(
             response_output_count = usage["output"]

         logger.info(
-            "{} request: {} with headers: {} and payload: {} completed with: {}",
+            "{} request: {} with headers: {} and params: {} and payload: {} completed "
+            "with: {}",
             self.__class__.__name__,
             request_id,
             headers,
+            params,
             payload,
             response_value,
         )
@@ -555,6 +601,7 @@ async def _iterative_completions_request(
         request_args = RequestArgs(
             target=target,
             headers=headers,
+            params=params,
             payload=payload,
             timeout=self.timeout,
             http2=self.http2,
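Because the resolved dict is handed straight to httpx as `params=`, httpx itself builds the final query string. A small sketch of the effect; the URL and parameter name are hypothetical:

```python
import httpx

# httpx encodes params={...} into the request's query string, so an
# extra_query of {"api-version": "2024-06-01"} on the models endpoint yields:
request = httpx.Request(
    "GET", "http://localhost:8000/v1/models", params={"api-version": "2024-06-01"}
)
print(request.url)  # http://localhost:8000/v1/models?api-version=2024-06-01
```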