Commit e6f3dfc

Allow extra query params to be sent to the OpenAI server (#146)
Prior to the `openai_server` -> `openai_http` refactor (#91), we were using the `extra_query` parameter [in the OpenAI client](https://github.com/openai/openai-python/blob/fad098ffad7982a5150306a3d17f51ffef574f2e/src/openai/resources/models.py#L50) to send custom query parameters to the OpenAI server in requests made by guidellm. This PR adds that parameter to the new `OpenAIHTTPBackend`, making it possible to add custom query parameters that are included in every request sent to the server.
1 parent 442e7fa
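For orientation, a minimal sketch of how the restored parameter is meant to be used (the target URL and query values here are illustrative, not taken from the diff):

```python
from guidellm.backend.openai import OpenAIHTTPBackend

# A flat extra_query dict is attached to every request the backend sends.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",  # illustrative server address
    extra_query={"api-version": "2024-06-01"},  # illustrative query parameter
)
# GET /v1/models, POST /v1/completions, and POST /v1/chat/completions
# all go out with ?api-version=2024-06-01 appended.
```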

File tree: 7 files changed, +72 −10 lines

src/guidellm/backend/openai.py

Lines changed: 56 additions & 9 deletions

```diff
@@ -17,12 +17,24 @@
 )
 from guidellm.config import settings
 
-__all__ = ["CHAT_COMPLETIONS_PATH", "TEXT_COMPLETIONS_PATH", "OpenAIHTTPBackend"]
+__all__ = [
+    "CHAT_COMPLETIONS",
+    "CHAT_COMPLETIONS_PATH",
+    "MODELS",
+    "TEXT_COMPLETIONS",
+    "TEXT_COMPLETIONS_PATH",
+    "OpenAIHTTPBackend",
+]
 
 
 TEXT_COMPLETIONS_PATH = "/v1/completions"
 CHAT_COMPLETIONS_PATH = "/v1/chat/completions"
 
+EndpointType = Literal["chat_completions", "models", "text_completions"]
+CHAT_COMPLETIONS: EndpointType = "chat_completions"
+MODELS: EndpointType = "models"
+TEXT_COMPLETIONS: EndpointType = "text_completions"
+
 
 @Backend.register("openai_http")
 class OpenAIHTTPBackend(Backend):
@@ -53,6 +65,11 @@ class OpenAIHTTPBackend(Backend):
         If not provided, the default value from settings is used.
     :param max_output_tokens: The maximum number of tokens to request for completions.
         If not provided, the default maximum tokens provided from settings is used.
+    :param extra_query: Query parameters to include in requests to the OpenAI server.
+        If "chat_completions", "models", or "text_completions" are included as keys,
+        the values of these keys will be used as the parameters for the respective
+        endpoint.
+        If not provided, no extra query parameters are added.
     """
 
     def __init__(
@@ -66,6 +83,7 @@ def __init__(
         http2: Optional[bool] = True,
         follow_redirects: Optional[bool] = None,
         max_output_tokens: Optional[int] = None,
+        extra_query: Optional[dict] = None,
     ):
         super().__init__(type_="openai_http")
         self._target = target or settings.openai.base_url
@@ -101,6 +119,7 @@ def __init__(
             if max_output_tokens is not None
             else settings.openai.max_output_tokens
         )
+        self.extra_query = extra_query
         self._async_client: Optional[httpx.AsyncClient] = None
 
     @property
@@ -174,7 +193,10 @@ async def available_models(self) -> list[str]:
         """
         target = f"{self.target}/v1/models"
         headers = self._headers()
-        response = await self._get_async_client().get(target, headers=headers)
+        params = self._params(MODELS)
+        response = await self._get_async_client().get(
+            target, headers=headers, params=params
+        )
         response.raise_for_status()
 
         models = []
@@ -219,6 +241,7 @@ async def text_completions(  # type: ignore[override]
         )
 
         headers = self._headers()
+        params = self._params(TEXT_COMPLETIONS)
         payload = self._completions_payload(
             orig_kwargs=kwargs,
             max_output_tokens=output_token_count,
@@ -232,14 +255,16 @@ async def text_completions(  # type: ignore[override]
                 request_prompt_tokens=prompt_token_count,
                 request_output_tokens=output_token_count,
                 headers=headers,
+                params=params,
                 payload=payload,
             ):
                 yield resp
         except Exception as ex:
             logger.error(
-                "{} request with headers: {} and payload: {} failed: {}",
+                "{} request with headers: {} and params: {} and payload: {} failed: {}",
                 self.__class__.__name__,
                 headers,
+                params,
                 payload,
                 ex,
             )
@@ -291,6 +316,7 @@ async def chat_completions(  # type: ignore[override]
         """
         logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
         headers = self._headers()
+        params = self._params(CHAT_COMPLETIONS)
         messages = (
             content if raw_content else self._create_chat_messages(content=content)
         )
@@ -307,14 +333,16 @@ async def chat_completions(  # type: ignore[override]
                 request_prompt_tokens=prompt_token_count,
                 request_output_tokens=output_token_count,
                 headers=headers,
+                params=params,
                 payload=payload,
             ):
                 yield resp
         except Exception as ex:
             logger.error(
-                "{} request with headers: {} and payload: {} failed: {}",
+                "{} request with headers: {} and params: {} and payload: {} failed: {}",
                 self.__class__.__name__,
                 headers,
+                params,
                 payload,
                 ex,
             )
@@ -355,6 +383,19 @@ def _headers(self) -> dict[str, str]:
 
         return headers
 
+    def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
+        if self.extra_query is None:
+            return {}
+
+        if (
+            CHAT_COMPLETIONS in self.extra_query
+            or MODELS in self.extra_query
+            or TEXT_COMPLETIONS in self.extra_query
+        ):
+            return self.extra_query.get(endpoint_type, {})
+
+        return self.extra_query
+
     def _completions_payload(
         self, orig_kwargs: Optional[dict], max_output_tokens: Optional[int], **kwargs
     ) -> dict:
@@ -451,8 +492,9 @@ async def _iterative_completions_request(
         request_id: Optional[str],
         request_prompt_tokens: Optional[int],
         request_output_tokens: Optional[int],
-        headers: dict,
-        payload: dict,
+        headers: dict[str, str],
+        params: dict[str, str],
+        payload: dict[str, Any],
     ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]:
         if type_ == "text_completions":
             target = f"{self.target}{TEXT_COMPLETIONS_PATH}"
@@ -463,14 +505,16 @@ async def _iterative_completions_request(
 
         logger.info(
             "{} making request: {} to target: {} using http2: {} following "
-            "redirects: {} for timeout: {} with headers: {} and payload: {}",
+            "redirects: {} for timeout: {} with headers: {} and params: {} and "
+            "payload: {}",
             self.__class__.__name__,
             request_id,
             target,
             self.http2,
             self.follow_redirects,
             self.timeout,
             headers,
+            params,
             payload,
         )
 
@@ -498,7 +542,7 @@ async def _iterative_completions_request(
         start_time = time.time()
 
         async with self._get_async_client().stream(
-            "POST", target, headers=headers, json=payload
+            "POST", target, headers=headers, params=params, json=payload
         ) as stream:
             stream.raise_for_status()
 
@@ -542,10 +586,12 @@ async def _iterative_completions_request(
                 response_output_count = usage["output"]
 
         logger.info(
-            "{} request: {} with headers: {} and payload: {} completed with: {}",
+            "{} request: {} with headers: {} and params: {} and payload: {} completed"
+            " with: {}",
             self.__class__.__name__,
             request_id,
             headers,
+            params,
             payload,
             response_value,
         )
@@ -555,6 +601,7 @@ async def _iterative_completions_request(
             request_args=RequestArgs(
                 target=target,
                 headers=headers,
+                params=params,
                 payload=payload,
                 timeout=self.timeout,
                 http2=self.http2,
```
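The `_params` helper above gives `extra_query` two shapes: a flat dict is sent to every endpoint, while a dict keyed by `"chat_completions"`, `"models"`, or `"text_completions"` is split per endpoint, with unkeyed endpoints receiving no extra parameters. A sketch of both shapes, using made-up parameter names and assuming the default target from settings:

```python
from guidellm.backend.openai import (
    CHAT_COMPLETIONS,
    MODELS,
    TEXT_COMPLETIONS,
    OpenAIHTTPBackend,
)

# Flat dict: every endpoint receives the same parameters.
backend = OpenAIHTTPBackend(extra_query={"tenant": "team-a"})
assert backend._params(MODELS) == {"tenant": "team-a"}
assert backend._params(TEXT_COMPLETIONS) == {"tenant": "team-a"}

# Endpoint-keyed dict: each endpoint receives only its own parameters.
backend = OpenAIHTTPBackend(
    extra_query={
        MODELS: {"tenant": "team-a"},
        CHAT_COMPLETIONS: {"tenant": "team-a", "priority": "high"},
    }
)
assert backend._params(MODELS) == {"tenant": "team-a"}
assert backend._params(CHAT_COMPLETIONS) == {"tenant": "team-a", "priority": "high"}
assert backend._params(TEXT_COMPLETIONS) == {}  # no "text_completions" key
```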

src/guidellm/backend/response.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -48,6 +48,7 @@ class RequestArgs(StandardBaseModel):
 
     :param target: The target URL or function for the request.
     :param headers: The headers, if any, included in the request such as authorization.
+    :param params: The query parameters, if any, included in the request.
     :param payload: The payload / arguments for the request including the prompt /
         content and other configurations.
     :param timeout: The timeout for the request in seconds, if any.
@@ -57,6 +58,7 @@ class RequestArgs(StandardBaseModel):
 
     target: str
     headers: dict[str, str]
+    params: dict[str, str]
     payload: dict[str, Any]
     timeout: Optional[float] = None
     http2: Optional[bool] = None
```
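With the new field, query parameters travel alongside `headers` and `payload` on the recorded request. A small illustrative sketch, assuming `StandardBaseModel` is a Pydantic v2 model (as `model_dump` below presumes) and with made-up values:

```python
from guidellm.backend.response import RequestArgs

args = RequestArgs(
    target="http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer token"},
    params={"api-version": "2024-06-01"},  # illustrative query parameter
    payload={"messages": [{"role": "user", "content": "Hello"}]},
)
# The query parameters round-trip through serialization like any other field.
assert args.model_dump()["params"] == {"api-version": "2024-06-01"}
```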

src/guidellm/dataset/synthetic.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -200,7 +200,9 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
 
 class SyntheticDatasetCreator(DatasetCreator):
     @classmethod
-    def is_supported(cls, data: Any, data_args: Optional[dict[str, Any]]) -> bool:  # noqa: ARG003
+    def is_supported(
+        cls, data: Any, data_args: Optional[dict[str, Any]]  # noqa: ARG003
+    ) -> bool:
         if (
             isinstance(data, Path)
             and data.exists()
```

src/guidellm/scheduler/worker.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -475,6 +475,7 @@ def _handle_response(
                 request_args=RequestArgs(
                     target=self.backend.target,
                     headers={},
+                    params={},
                     payload={},
                 ),
                 start_time=resolve_start_time,
@@ -490,6 +491,7 @@
                 request_args=RequestArgs(
                     target=self.backend.target,
                     headers={},
+                    params={},
                     payload={},
                 ),
                 start_time=response.start_time,
```

tests/unit/backend/test_openai_backend.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -18,6 +18,7 @@ def test_openai_http_backend_default_initialization():
     assert backend.http2 is True
     assert backend.follow_redirects is True
     assert backend.max_output_tokens == settings.openai.max_output_tokens
+    assert backend.extra_query is None
 
 
 @pytest.mark.smoke
@@ -32,6 +33,7 @@ def test_openai_http_backend_intialization():
         http2=False,
         follow_redirects=False,
         max_output_tokens=100,
+        extra_query={"foo": "bar"},
     )
     assert backend.target == "http://test-target"
     assert backend.model == "test-model"
@@ -42,6 +44,7 @@ def test_openai_http_backend_intialization():
     assert backend.http2 is False
     assert backend.follow_redirects is False
     assert backend.max_output_tokens == 100
+    assert backend.extra_query == {"foo": "bar"}
 
 
 @pytest.mark.smoke
```

tests/unit/backend/test_response.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -76,6 +76,7 @@ def test_request_args_default_initialization():
     args = RequestArgs(
         target="http://example.com",
         headers={},
+        params={},
         payload={},
     )
     assert args.timeout is None
@@ -90,6 +91,7 @@ def test_request_args_initialization():
         headers={
             "Authorization": "Bearer token",
         },
+        params={},
         payload={
             "query": "Hello, world!",
         },
@@ -110,6 +112,7 @@ def test_response_args_marshalling():
     args = RequestArgs(
         target="http://example.com",
         headers={"Authorization": "Bearer token"},
+        params={},
         payload={"query": "Hello, world!"},
         timeout=10.0,
         http2=True,
@@ -128,6 +131,7 @@ def test_response_summary_default_initialization():
         request_args=RequestArgs(
             target="http://example.com",
             headers={},
+            params={},
             payload={},
         ),
         start_time=0.0,
@@ -158,6 +162,7 @@ def test_response_summary_initialization():
         request_args=RequestArgs(
             target="http://example.com",
             headers={},
+            params={},
             payload={},
         ),
         start_time=1.0,
```

tests/unit/mock_backend.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -142,6 +142,7 @@ async def _text_prompt_response_generator(
             request_args=RequestArgs(
                 target=self.target,
                 headers={},
+                params={},
                 payload={"prompt": prompt, "output_token_count": output_token_count},
             ),
             iterations=len(tokens),
```
