Skip to content

Commit 0f299e8

Browse files
authored
Update the AQUA documentation and the AQUA OpenAI client to support multiple inference endpoints in OCI Model Deployment. (#1212)
1 parent be35fa7 commit 0f299e8

File tree

4 files changed

+91
-32
lines changed

4 files changed

+91
-32
lines changed

ads/aqua/client/client.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
6161

6262
def __init__(self, signer: Optional[oci.signer.Signer] = None):
6363
"""
64-
Initialize the HttpxOCIAuth instance.
64+
Initializes the authentication handler with the given or default OCI signer.
6565
66-
Args:
67-
signer (oci.signer.Signer): The OCI signer to use for signing requests.
66+
Parameters
67+
----------
68+
signer : oci.signer.Signer, optional
69+
The OCI signer instance to use. If None, a default signer will be retrieved.
6870
"""
69-
70-
self.signer = signer or authutil.default_signer().get("signer")
71+
try:
72+
self.signer = signer or authutil.default_signer().get("signer")
73+
if not self.signer:
74+
raise ValueError("OCI signer could not be initialized.")
75+
except Exception as e:
76+
logger.error("Failed to initialize OCI signer: %s", e)
77+
raise
7178

7279
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
7380
"""
@@ -80,21 +87,31 @@ def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
8087
httpx.Request: The signed HTTPX request.
8188
"""
8289
# Create a requests.Request object from the HTTPX request
83-
req = requests.Request(
84-
method=request.method,
85-
url=str(request.url),
86-
headers=dict(request.headers),
87-
data=request.content,
88-
)
89-
prepared_request = req.prepare()
90+
try:
91+
req = requests.Request(
92+
method=request.method,
93+
url=str(request.url),
94+
headers=dict(request.headers),
95+
data=request.content,
96+
)
97+
prepared_request = req.prepare()
98+
self.signer.do_request_sign(prepared_request)
99+
100+
# Replace headers on the original HTTPX request with signed headers
101+
request.headers.update(prepared_request.headers)
102+
logger.debug("Successfully signed request to %s", request.url)
90103

91-
# Sign the request using the OCI Signer
92-
self.signer.do_request_sign(prepared_request)
104+
# Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
105+
if (
106+
request.method in ["GET", "DELETE"]
107+
and "content-length" not in request.headers
108+
):
109+
request.headers["content-length"] = "0"
93110

94-
# Update the original HTTPX request with the signed headers
95-
request.headers.update(prepared_request.headers)
111+
except Exception as e:
112+
logger.error("Failed to sign request to %s: %s", request.url, e)
113+
raise
96114

97-
# Proceed with the request
98115
yield request
99116

100117

@@ -330,8 +347,8 @@ def _prepare_headers(
330347
"Content-Type": "application/json",
331348
"Accept": "text/event-stream" if stream else "application/json",
332349
}
333-
if stream:
334-
default_headers["enable-streaming"] = "true"
350+
# if stream:
351+
# default_headers["enable-streaming"] = "true"
335352
if headers:
336353
default_headers.update(headers)
337354

@@ -495,7 +512,7 @@ def generate(
495512
prompt: str,
496513
payload: Optional[Dict[str, Any]] = None,
497514
headers: Optional[Dict[str, str]] = None,
498-
stream: bool = True,
515+
stream: bool = False,
499516
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
500517
"""
501518
Generate text completion for the given prompt.
@@ -521,7 +538,7 @@ def chat(
521538
messages: List[Dict[str, Any]],
522539
payload: Optional[Dict[str, Any]] = None,
523540
headers: Optional[Dict[str, str]] = None,
524-
stream: bool = True,
541+
stream: bool = False,
525542
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
526543
"""
527544
Perform a chat interaction with the model.

ads/aqua/client/openai_client.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ModelDeploymentBaseEndpoint(ExtendedEnum):
3232
"""Supported base endpoints for model deployments."""
3333

3434
PREDICT = "predict"
35-
PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
35+
PREDICT_WITH_RESPONSE_STREAM = "predictWithResponseStream"
3636

3737

3838
class AquaOpenAIMixin:
@@ -51,9 +51,9 @@ def _patch_route(self, original_path: str) -> str:
5151
Returns:
5252
str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
5353
"""
54-
normalized_path = original_path.lower().rstrip("/")
54+
normalized_path = original_path.rstrip("/")
5555

56-
match = re.search(r"/predict(withresponsestream)?", normalized_path)
56+
match = re.search(r"/predict(WithResponseStream)?", normalized_path)
5757
if not match:
5858
logger.debug("Route header cannot be resolved from path: %s", original_path)
5959
return ""
@@ -71,7 +71,7 @@ def _patch_route(self, original_path: str) -> str:
7171
"Route suffix does not start with a version prefix (e.g., '/v1'). "
7272
"This may lead to compatibility issues with OpenAI-style endpoints. "
7373
"Consider updating the URL to include a version prefix, "
74-
"such as '/predict/v1' or '/predictwithresponsestream/v1'."
74+
"such as '/predict/v1' or '/predictWithResponseStream/v1'."
7575
)
7676
# route_suffix = f"v1/{route_suffix}"
7777

@@ -124,13 +124,13 @@ def _patch_headers(self, request: httpx.Request) -> None:
124124

125125
def _patch_url(self) -> httpx.URL:
126126
"""
127-
Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.
127+
Strips any suffixes from the base URL to retain only the `/predict` or `/predictWithResponseStream` path.
128128
129129
Returns:
130130
httpx.URL: The normalized base URL with the correct model deployment path.
131131
"""
132-
base_path = f"{self.base_url.path.lower().rstrip('/')}/"
133-
match = re.search(r"/predict(withresponsestream)?/", base_path)
132+
base_path = f"{self.base_url.path.rstrip('/')}/"
133+
match = re.search(r"/predict(WithResponseStream)?/", base_path)
134134
if match:
135135
trimmed = base_path[: match.end() - 1]
136136
return self.base_url.copy_with(path=trimmed)
@@ -144,7 +144,7 @@ def _prepare_request_common(self, request: httpx.Request) -> None:
144144
145145
This includes:
146146
- Patching headers with streaming and routing info.
147-
- Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.
147+
- Normalizing the URL path to include only `/predict` or `/predictWithResponseStream`.
148148
149149
Args:
150150
request (httpx.Request): The outgoing HTTPX request.
@@ -176,6 +176,7 @@ def __init__(
176176
http_client: Optional[httpx.Client] = None,
177177
http_client_kwargs: Optional[Dict[str, Any]] = None,
178178
_strict_response_validation: bool = False,
179+
patch_headers: bool = False,
179180
**kwargs: Any,
180181
) -> None:
181182
"""
@@ -196,6 +197,7 @@ def __init__(
196197
http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
197198
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
198199
_strict_response_validation (bool, optional): Enable strict response validation.
200+
patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
199201
**kwargs: Additional keyword arguments passed to the parent __init__.
200202
"""
201203
if http_client is None:
@@ -207,6 +209,8 @@ def __init__(
207209
logger.debug("API key not provided; using default placeholder for OCI.")
208210
api_key = "OCI"
209211

212+
self.patch_headers = patch_headers
213+
210214
super().__init__(
211215
api_key=api_key,
212216
organization=organization,
@@ -229,7 +233,8 @@ def _prepare_request(self, request: httpx.Request) -> None:
229233
Args:
230234
request (httpx.Request): The outgoing HTTP request.
231235
"""
232-
self._prepare_request_common(request)
236+
if self.patch_headers:
237+
self._prepare_request_common(request)
233238

234239

235240
class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
@@ -248,6 +253,7 @@ def __init__(
248253
http_client: Optional[httpx.Client] = None,
249254
http_client_kwargs: Optional[Dict[str, Any]] = None,
250255
_strict_response_validation: bool = False,
256+
patch_headers: bool = False,
251257
**kwargs: Any,
252258
) -> None:
253259
"""
@@ -269,6 +275,7 @@ def __init__(
269275
http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
270276
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
271277
_strict_response_validation (bool, optional): Enable strict response validation.
278+
patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
272279
**kwargs: Additional keyword arguments passed to the parent __init__.
273280
"""
274281
if http_client is None:
@@ -280,6 +287,8 @@ def __init__(
280287
logger.debug("API key not provided; using default placeholder for OCI.")
281288
api_key = "OCI"
282289

290+
self.patch_headers = patch_headers
291+
283292
super().__init__(
284293
api_key=api_key,
285294
organization=organization,
@@ -302,4 +311,5 @@ async def _prepare_request(self, request: httpx.Request) -> None:
302311
Args:
303312
request (httpx.Request): The outgoing HTTP request.
304313
"""
305-
self._prepare_request_common(request)
314+
if self.patch_headers:
315+
self._prepare_request_common(request)

docs/source/user_guide/large_language_model/aqua_client.rst

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,35 @@ The asynchronous client, ``AsynOpenAI``, extends the AsyncOpenAI client. If no a
277277
print(event)
278278
279279
asyncio.run(test_async())
280+
281+
282+
Using the Native OpenAI Client
283+
------------------------------
284+
285+
If you prefer to use the **original `openai.OpenAI` client**, you must manually provide:
286+
287+
- A custom HTTP client created via `ads.aqua.get_httpx_client()`, and
288+
- `api_key="OCI"` (required for SDK compatibility).
289+
290+
.. code-block:: python
291+
292+
import ads
293+
from openai import OpenAI
294+
295+
ads.set_auth(auth="security_token")
296+
297+
# Create the patched HTTP client with OCI signer
298+
http_client = ads.aqua.get_httpx_client()
299+
300+
client = OpenAI(
301+
api_key="OCI",
302+
base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<OCID>/predict/v1",
303+
http_client=http_client
304+
)
305+
306+
response = client.chat.completions.create(
307+
model="odsc-llm",
308+
messages=[{"role": "user", "content": "Write a short story about a unicorn."}],
309+
)
310+
311+
print(response)

tests/unitary/with_extras/aqua/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def test_prepare_headers_stream(self):
257257
expected_headers = {
258258
"Content-Type": "application/json",
259259
"Accept": "text/event-stream",
260-
"enable-streaming": "true",
260+
# "enable-streaming": "true",
261261
"Custom-Header": "Value",
262262
}
263263
assert result == expected_headers

0 commit comments

Comments
 (0)