Skip to content

Commit c74eebd

Browse files
committed
Adds HttpxOCIAuth AQUA client integration section to documentation.
1 parent 35c03c6 commit c74eebd

File tree

5 files changed

+88
-28
lines changed

5 files changed

+88
-28
lines changed

ads/aqua/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from logging import getLogger
88

99
from ads import logger, set_auth
10-
from ads.aqua.client.client import AsyncClient, Client
10+
from ads.aqua.client.client import AsyncClient, Client, HttpxOCIAuth
1111
from ads.aqua.common.utils import fetch_service_compartment
1212
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1313

ads/aqua/client/client.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,23 @@
5151
logger = logging.getLogger(__name__)
5252

5353

54-
class OCIAuth(httpx.Auth):
54+
class HttpxOCIAuth(httpx.Auth):
5555
"""
5656
Custom HTTPX authentication class that uses the OCI Signer for request signing.
5757
5858
Attributes:
5959
signer (oci.signer.Signer): The OCI signer used to sign requests.
6060
"""
6161

62-
def __init__(self, signer: oci.signer.Signer):
62+
def __init__(self, signer: Optional[oci.signer.Signer] = None):
6363
"""
64-
Initialize the OCIAuth instance.
64+
Initialize the HttpxOCIAuth instance.
6565
6666
Args:
6767
signer (oci.signer.Signer): The OCI signer to use for signing requests.
6868
"""
69-
self.signer = signer
69+
70+
self.signer = signer or authutil.default_signer().get("signer")
7071

7172
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
7273
"""
@@ -256,7 +257,7 @@ def __init__(
256257
auth = auth or authutil.default_signer()
257258
if not callable(auth.get("signer")):
258259
raise ValueError("Auth object must have a 'signer' callable attribute.")
259-
self.auth = OCIAuth(auth["signer"])
260+
self.auth = HttpxOCIAuth(auth["signer"])
260261

261262
logger.debug(
262263
f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, "
@@ -352,7 +353,7 @@ def __init__(self, *args, **kwargs) -> None:
352353
**kwargs: Keyword arguments forwarded to BaseClient.
353354
"""
354355
super().__init__(*args, **kwargs)
355-
self._client = httpx.Client(timeout=self.timeout)
356+
self._client = httpx.Client(timeout=self.timeout, auth=self.auth)
356357

357358
def is_closed(self) -> bool:
358359
return self._client.is_closed
@@ -400,7 +401,6 @@ def _request(
400401
response = self._client.post(
401402
self.endpoint,
402403
headers=self._prepare_headers(stream=False, headers=headers),
403-
auth=self.auth,
404404
json=payload,
405405
)
406406
logger.debug(f"Received response with status code: {response.status_code}")
@@ -447,7 +447,6 @@ def _stream(
447447
"POST",
448448
self.endpoint,
449449
headers=self._prepare_headers(stream=True, headers=headers),
450-
auth=self.auth,
451450
json={**payload, "stream": True},
452451
) as response:
453452
try:
@@ -581,7 +580,7 @@ def __init__(self, *args, **kwargs) -> None:
581580
**kwargs: Keyword arguments forwarded to BaseClient.
582581
"""
583582
super().__init__(*args, **kwargs)
584-
self._client = httpx.AsyncClient(timeout=self.timeout)
583+
self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth)
585584

586585
def is_closed(self) -> bool:
587586
return self._client.is_closed
@@ -637,7 +636,6 @@ async def _request(
637636
response = await self._client.post(
638637
self.endpoint,
639638
headers=self._prepare_headers(stream=False, headers=headers),
640-
auth=self.auth,
641639
json=payload,
642640
)
643641
logger.debug(f"Received response with status code: {response.status_code}")
@@ -683,7 +681,6 @@ async def _stream(
683681
"POST",
684682
self.endpoint,
685683
headers=self._prepare_headers(stream=True, headers=headers),
686-
auth=self.auth,
687684
json={**payload, "stream": True},
688685
) as response:
689686
try:

ads/common/auth.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ads.common import logger
2121
from ads.common.decorator.deprecate import deprecated
2222
from ads.common.extended_enum import ExtendedEnum
23+
from ads.config import OCI_ODSC_SERVICE_ENDPOINT
2324

2425
SECURITY_TOKEN_LEFT_TIME = 600
2526

@@ -88,13 +89,15 @@ def set_auth(
8889
auth: Optional[str] = AuthType.API_KEY,
8990
oci_config_location: Optional[str] = DEFAULT_LOCATION,
9091
profile: Optional[str] = DEFAULT_PROFILE,
91-
config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
92-
if os.environ.get("OCI_RESOURCE_REGION")
93-
else {},
92+
config: Optional[Dict] = (
93+
{"region": os.environ["OCI_RESOURCE_REGION"]}
94+
if os.environ.get("OCI_RESOURCE_REGION")
95+
else {}
96+
),
9497
signer: Optional[Any] = None,
9598
signer_callable: Optional[Callable] = None,
96-
signer_kwargs: Optional[Dict] = {},
97-
client_kwargs: Optional[Dict] = {},
99+
signer_kwargs: Optional[Dict] = None,
100+
client_kwargs: Optional[Dict] = None,
98101
) -> None:
99102
"""
100103
Sets the default authentication type.
@@ -195,6 +198,12 @@ def set_auth(
195198
>>> # instance principals authentication dictionary created based on callable with kwargs parameters:
196199
>>> ads.set_auth(signer_callable=signer_callable, signer_kwargs=signer_kwargs)
197200
"""
201+
signer_kwargs = signer_kwargs or {}
202+
client_kwargs = {
203+
"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT,
204+
**(client_kwargs or {}),
205+
}
206+
198207
auth_state = AuthState()
199208

200209
valid_auth_keys = AuthFactory.classes.keys()
@@ -258,9 +267,11 @@ def api_keys(
258267
"""
259268
signer_args = dict(
260269
oci_config=oci_config if isinstance(oci_config, Dict) else {},
261-
oci_config_location=oci_config
262-
if isinstance(oci_config, str)
263-
else os.path.expanduser(DEFAULT_LOCATION),
270+
oci_config_location=(
271+
oci_config
272+
if isinstance(oci_config, str)
273+
else os.path.expanduser(DEFAULT_LOCATION)
274+
),
264275
oci_key_profile=profile,
265276
client_kwargs=client_kwargs,
266277
)
@@ -334,9 +345,11 @@ def security_token(
334345
"""
335346
signer_args = dict(
336347
oci_config=oci_config if isinstance(oci_config, Dict) else {},
337-
oci_config_location=oci_config
338-
if isinstance(oci_config, str)
339-
else os.path.expanduser(DEFAULT_LOCATION),
348+
oci_config_location=(
349+
oci_config
350+
if isinstance(oci_config, str)
351+
else os.path.expanduser(DEFAULT_LOCATION)
352+
),
340353
oci_key_profile=profile,
341354
client_kwargs=client_kwargs,
342355
)
@@ -485,6 +498,7 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
485498
>>> auth = ads.auth.default_signer() # signer_callable instantiated
486499
>>> oc.OCIClientFactory(**auth).object_storage # Creates Object storage client using instance principal authentication
487500
"""
501+
default_client_kwargs = {"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT}
488502
auth_state = AuthState()
489503
if auth_state.oci_signer or auth_state.oci_signer_callable:
490504
configuration = ads.telemetry.update_oci_client_config(auth_state.oci_config)
@@ -496,6 +510,7 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
496510
"config": configuration,
497511
"signer": signer,
498512
"client_kwargs": {
513+
**default_client_kwargs,
499514
**(auth_state.oci_client_kwargs or {}),
500515
**(client_kwargs or {}),
501516
},
@@ -508,6 +523,7 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
508523
oci_key_profile=auth_state.oci_key_profile,
509524
oci_config=auth_state.oci_config,
510525
client_kwargs={
526+
**default_client_kwargs,
511527
**(auth_state.oci_client_kwargs or {}),
512528
**(client_kwargs or {}),
513529
},

docs/source/user_guide/large_language_model/aqua_client.rst

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,50 @@ The following examples demonstrate how to perform the same operations using the
129129
input=["one", "two"]
130130
)
131131
print(response)
132+
133+
134+
HTTPX Client Integration with OCI Authentication (HttpxOCIAuth)
135+
================================================================
136+
137+
.. versionadded:: 2.13.1
138+
139+
In recent updates to the client, a new class, ``HttpxOCIAuth``, has been introduced.
140+
This class allows signing of HTTPX requests using OCI signers, making the HTTPX client compatible
141+
with the LLM models deployed on the OCI Model Deployment service. With this integration, you can use
142+
HTTPX-based clients with any compatible third-party libraries (e.g., the OpenAi client).
143+
144+
Usage
145+
-----
146+
147+
**Synchronous HTTPX Client**
148+
149+
.. code-block:: python3
150+
151+
import httpx
152+
import ads
153+
from ads.aqua import HttpxOCIAuth
154+
155+
ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")
156+
client = httpx.Client(auth=HttpxOCIAuth())
157+
158+
response = client.post(
159+
url="https://<MD_OCID>/predict",
160+
json={
161+
"model": "odsc-llm",
162+
"prompt": "Tell me a joke."
163+
},
164+
)
165+
166+
response.raise_for_status()
167+
json_response = response.json()
168+
169+
**Asynchronous HTTPX Client**
170+
171+
.. code-block:: python3
172+
173+
import httpx
174+
import ads
175+
from ads.aqua import HttpxOCIAuth
176+
177+
ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")
178+
async_client = httpx.AsyncClient(auth=HttpxOCIAuth())

tests/unitary/with_extras/aqua/test_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515
BaseClient,
1616
Client,
1717
ExtendedRequestError,
18-
OCIAuth,
18+
HttpxOCIAuth,
1919
_create_retry_decorator,
2020
_retry_decorator,
2121
_should_retry_exception,
2222
)
2323
from ads.common import auth as authutil
2424

2525

26-
class TestOCIAuth:
27-
"""Unit tests for OCIAuth class."""
26+
class TestHttpxOCIAuth:
27+
"""Unit tests for HttpxOCIAuth class."""
2828

2929
def setup_method(self):
3030
self.signer_mock = Mock()
31-
self.oci_auth = OCIAuth(self.signer_mock)
31+
self.oci_auth = HttpxOCIAuth(self.signer_mock)
3232

3333
def test_auth_flow(self):
3434
"""Ensures that the auth_flow signs the request correctly."""
@@ -226,7 +226,7 @@ def test_init(self):
226226
assert self.base_client.retries == self.retries
227227
assert self.base_client.backoff_factor == self.backoff_factor
228228
assert self.base_client.timeout == self.timeout
229-
assert isinstance(self.base_client.auth, OCIAuth)
229+
assert isinstance(self.base_client.auth, HttpxOCIAuth)
230230

231231
def test_init_default_auth(self):
232232
"""Ensures that default auth is used when auth is None."""

0 commit comments

Comments
 (0)