@@ -70,6 +70,14 @@ class OpenAIHTTPBackend(Backend):
70
70
the values of these keys will be used as the parameters for the respective
71
71
endpoint.
72
72
If not provided, no extra query parameters are added.
73
+ :param extra_body: Body parameters to include in requests to the OpenAI server.
74
+ If "chat_completions", "models", or "text_completions" are included as keys,
75
+ the values of these keys will be included in the body for the respective
76
+ endpoint.
77
+ If not provided, no extra body parameters are added.
78
+ :param remove_from_body: Parameters that should be removed from the body of each
79
+ request.
80
+ If not provided, no parameters are removed from the body.
73
81
"""
74
82
75
83
def __init__ (
@@ -85,6 +93,7 @@ def __init__(
85
93
max_output_tokens : Optional [int ] = None ,
86
94
extra_query : Optional [dict ] = None ,
87
95
extra_body : Optional [dict ] = None ,
96
+ remove_from_body : Optional [list [str ]] = None ,
88
97
):
89
98
super ().__init__ (type_ = "openai_http" )
90
99
self ._target = target or settings .openai .base_url
@@ -122,6 +131,7 @@ def __init__(
122
131
)
123
132
self .extra_query = extra_query
124
133
self .extra_body = extra_body
134
+ self .remove_from_body = remove_from_body
125
135
self ._async_client : Optional [httpx .AsyncClient ] = None
126
136
127
137
@property
@@ -253,9 +263,8 @@ async def text_completions( # type: ignore[override]
253
263
254
264
headers = self ._headers ()
255
265
params = self ._params (TEXT_COMPLETIONS )
256
- body = self ._body (TEXT_COMPLETIONS )
257
266
payload = self ._completions_payload (
258
- body = body ,
267
+ endpoint_type = TEXT_COMPLETIONS ,
259
268
orig_kwargs = kwargs ,
260
269
max_output_tokens = output_token_count ,
261
270
prompt = prompt ,
@@ -330,12 +339,11 @@ async def chat_completions( # type: ignore[override]
330
339
logger .debug ("{} invocation with args: {}" , self .__class__ .__name__ , locals ())
331
340
headers = self ._headers ()
332
341
params = self ._params (CHAT_COMPLETIONS )
333
- body = self ._body (CHAT_COMPLETIONS )
334
342
messages = (
335
343
content if raw_content else self ._create_chat_messages (content = content )
336
344
)
337
345
payload = self ._completions_payload (
338
- body = body ,
346
+ endpoint_type = CHAT_COMPLETIONS ,
339
347
orig_kwargs = kwargs ,
340
348
max_output_tokens = output_token_count ,
341
349
messages = messages ,
@@ -411,7 +419,7 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
411
419
412
420
return self .extra_query
413
421
414
- def _body (self , endpoint_type : EndpointType ) -> dict [str , str ]:
422
+ def _extra_body (self , endpoint_type : EndpointType ) -> dict [str , Any ]:
415
423
if self .extra_body is None :
416
424
return {}
417
425
@@ -426,12 +434,12 @@ def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
426
434
427
435
def _completions_payload (
428
436
self ,
429
- body : Optional [ dict ] ,
437
+ endpoint_type : EndpointType ,
430
438
orig_kwargs : Optional [dict ],
431
439
max_output_tokens : Optional [int ],
432
440
** kwargs ,
433
441
) -> dict :
434
- payload = body or {}
442
+ payload = self . _extra_body ( endpoint_type )
435
443
payload .update (orig_kwargs or {})
436
444
payload .update (kwargs )
437
445
payload ["model" ] = self .model
@@ -455,6 +463,10 @@ def _completions_payload(
455
463
payload ["stop" ] = None
456
464
payload ["ignore_eos" ] = True
457
465
466
+ if self .remove_from_body :
467
+ for key in self .remove_from_body :
468
+ payload .pop (key , None )
469
+
458
470
return payload
459
471
460
472
@staticmethod
0 commit comments