
Commit 872e41b

Updated pr.
1 parent 25c26f2 commit 872e41b

File tree: 4 files changed (+52, -17 lines)

ads/llm/langchain/plugins/chat_models/oci_data_science.py
Lines changed: 21 additions & 5 deletions

@@ -93,7 +93,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
-        headers: Optional[Dict]
+        default_headers: Optional[Dict]
             The headers to be added to the Model Deployment request.

     Instantiate:
@@ -111,7 +111,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
                 "temperature": 0.2,
                 # other model parameters ...
             },
-            headers={
+            default_headers={
                 "route": "/v1/chat/completions",
                 # other request headers ...
             },
@@ -263,9 +263,6 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
    """Stop words to use when generating. Model output is cut off
    at the first occurrence of any of these substrings."""

-    headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
-    """The headers to be added to the Model Deployment request."""
-
    @model_validator(mode="before")
    @classmethod
    def validate_openai(cls, values: Any) -> Any:
@@ -300,6 +297,25 @@ def _default_params(self) -> Dict[str, Any]:
            "stream": self.streaming,
        }

+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
    def _generate(
        self,
        messages: List[BaseMessage],
ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py
Lines changed: 21 additions & 2 deletions

@@ -85,7 +85,7 @@ class BaseOCIModelDeployment(Serializable):
    max_retries: int = 3
    """Maximum number of retries to make when generating."""

-    headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT}
+    default_headers: Optional[Dict[str, Any]] = None
    """The headers to be added to the Model Deployment request."""

    @model_validator(mode="before")
@@ -127,7 +127,7 @@ def _headers(
        Returns:
            Dict: A dictionary containing the appropriate headers for the request.
        """
-        headers = self.headers
+        headers = self.default_headers or {}
        if is_async:
            signer = self.auth["signer"]
            _req = requests.Request("POST", self.endpoint, json=body)
@@ -485,6 +485,25 @@ def _identifying_params(self) -> Dict[str, Any]:
            **self._default_params,
        }

+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
    def _generate(
        self,
        prompts: List[str],
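
The override pattern above (the subclass injects its default "route", then spreads the base headers over it) means a user-supplied "route" in default_headers wins on key collisions. A small standalone sketch of just that merge behavior; Base and ChatModel here are toy stand-ins, not the actual ADS classes:

    from typing import Dict, Optional

    class Base:
        def __init__(self, default_headers: Optional[Dict[str, str]] = None):
            self.default_headers = default_headers

        def _headers(self) -> Dict[str, str]:
            # Base returns whatever the user supplied (or an empty dict).
            return dict(self.default_headers or {})

    class ChatModel(Base):
        def _headers(self) -> Dict[str, str]:
            # Subclass provides its default route; user headers, unpacked last,
            # take precedence when the same key appears in both.
            return {"route": "/v1/chat/completions", **super()._headers()}

    print(ChatModel()._headers())
    # {'route': '/v1/chat/completions'}
    print(ChatModel(default_headers={"route": "/custom", "x-trace": "1"})._headers())
    # {'route': '/custom', 'x-trace': '1'}
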

tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py
Lines changed: 5 additions & 5 deletions

@@ -26,7 +26,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
-CONST_ENDPOINT = "/v1/chat/completions"
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -124,7 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -137,7 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -152,7 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -191,7 +191,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",

tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py
Lines changed: 5 additions & 5 deletions

@@ -24,7 +24,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
-CONST_ENDPOINT = "/v1/completions"
+CONST_COMPLETION_ROUTE = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -117,7 +117,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION

@@ -130,7 +130,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -148,7 +148,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION

@@ -167,7 +167,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
