Skip to content

Commit e31bc07

Browse files
committed
Update model deployment LLM model to refresh signer.
1 parent 70fcd76 commit e31bc07

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

ads/llm/langchain/plugins/llm_md.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Dict, List, Optional
99

1010
import requests
11+
from oci.auth import signers
1112
from langchain.callbacks.manager import CallbackManagerForLLMRun
1213

1314
from ads.llm.langchain.plugins.base import BaseLLM
@@ -134,11 +135,22 @@ def send_request(
134135
header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
135136
or DEFAULT_CONTENT_TYPE_JSON
136137
)
138+
timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
137139
request_kwargs = {"json": data}
138140
request_kwargs["headers"] = header
139-
request_kwargs["auth"] = self.auth.get("signer")
140-
timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
141-
response = requests.post(endpoint, timeout=timeout, **request_kwargs, **kwargs)
141+
signer = self.auth.get("signer")
142+
143+
attempts = 0
144+
while attempts < 2:
145+
request_kwargs["auth"] = signer
146+
response = requests.post(
147+
endpoint, timeout=timeout, **request_kwargs, **kwargs
148+
)
149+
if response.status_code == 401 and self.is_principal_signer(signer):
150+
signer.refresh_security_token()
151+
attempts += 1
152+
continue
153+
break
142154

143155
try:
144156
response.raise_for_status()
@@ -155,6 +167,21 @@ def send_request(
155167

156168
return response_json
157169

170+
@staticmethod
171+
def is_principal_signer(signer):
172+
"""Checks if the signer is instance principal or resource principal signer."""
173+
if (
174+
isinstance(signer, signers.InstancePrincipalsSecurityTokenSigner)
175+
or isinstance(signer, signers.ResourcePrincipalsFederationSigner)
176+
or isinstance(signer, signers.EphemeralResourcePrincipalSigner)
177+
or isinstance(signer, signers.EphemeralResourcePrincipalV21Signer)
178+
or isinstance(signer, signers.NestedResourcePrincipals)
179+
or isinstance(signer, signers.OkeWorkloadIdentityResourcePrincipalSigner)
180+
):
181+
return True
182+
else:
183+
return False
184+
158185

159186
class ModelDeploymentTGI(ModelDeploymentLLM):
160187
"""OCI Data Science Model Deployment TGI Endpoint.

0 commit comments

Comments
 (0)