Commit de214d3 (2 parents: eab7c47 + 3cfccc2)

Added headers parameters for ADS Langchain (#1020)
File tree: 6 files changed (+97 −34 lines)


.github/workflows/run-unittests-py39-py310.yml

Lines changed: 9 additions & 9 deletions
@@ -74,16 +74,16 @@ jobs:
         name: "Test env setup"
         timeout-minutes: 30
 
-      - name: "Run hpo tests"
-        timeout-minutes: 10
-        shell: bash
-        if: ${{ matrix.name }} == "unitary"
-        run: |
-          set -x # print commands that are executed
+      # - name: "Run hpo tests"
+      #   timeout-minutes: 10
+      #   shell: bash
+      #   if: ${{ matrix.name }} == "unitary"
+      #   run: |
+      #     set -x # print commands that are executed
 
-          # Run hpo tests, which hangs if run together with all unitary tests
-          python -m pytest -v -p no:warnings -n auto --dist loadfile \
-            tests/unitary/with_extras/hpo
+      #     # Run hpo tests, which hangs if run together with all unitary tests
+      #     python -m pytest -v -p no:warnings -n auto --dist loadfile \
+      #       tests/unitary/with_extras/hpo
 
       - name: "Run unitary tests folder with maximum ADS dependencies"
         timeout-minutes: 60

ads/llm/guardrails/base.py

Lines changed: 6 additions & 5 deletions
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
 import datetime
 import functools
-import operator
 import importlib.util
+import operator
 import sys
+from typing import Any, List, Optional, Union
 
-from typing import Any, List, Dict, Tuple
 from langchain.schema.prompt import PromptValue
 from langchain.tools.base import BaseTool, ToolException
 from pydantic import BaseModel, model_validator
@@ -207,7 +206,9 @@ def _preprocess(self, input: Any) -> str:
             return input.to_string()
         return str(input)
 
-    def _to_args_and_kwargs(self, tool_input: Any) -> Tuple[Tuple, Dict]:
+    def _to_args_and_kwargs(
+        self, tool_input: Union[str, dict], tool_call_id: Optional[str]
+    ) -> tuple[tuple, dict]:
         if isinstance(tool_input, dict):
             return (), tool_input
         else:
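
The widened `_to_args_and_kwargs` signature tracks LangChain's updated `BaseTool` internals, which now pass a `tool_call_id` alongside the tool input. A minimal standalone sketch of the dispatch logic, with an illustrative function name (not the real method):

from typing import Optional, Union


def to_args_and_kwargs(
    tool_input: Union[str, dict], tool_call_id: Optional[str] = None
) -> tuple[tuple, dict]:
    """Dict inputs become keyword arguments; anything else is passed
    through as a single positional argument. `tool_call_id` is accepted
    for signature compatibility but unused here."""
    if isinstance(tool_input, dict):
        return (), tool_input
    return (tool_input,), {}


# A dict maps to kwargs; a plain string maps to one positional arg.
assert to_args_and_kwargs({"query": "hi"}) == ((), {"query": "hi"})
assert to_args_and_kwargs("hi") == (("hi",), {})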

ads/llm/langchain/plugins/chat_models/oci_data_science.py

Lines changed: 34 additions & 9 deletions
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 """Chat model for OCI data science model deployment endpoint."""
 
@@ -50,6 +49,7 @@
 )
 
 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
 
 
 def _is_pydantic_class(obj: Any) -> bool:
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
+        default_headers: Optional[Dict]
+            The headers to be added to the Model Deployment request.
 
     Instantiate:
         .. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
                 "temperature": 0.2,
                 # other model parameters ...
             },
+            default_headers={
+                "route": "/v1/chat/completions",
+                # other request headers ...
+            },
         )
 
     Invocation:
@@ -291,6 +297,25 @@ def _default_params(self) -> Dict[str, Any]:
             "stream": self.streaming,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -704,7 +729,7 @@ def _process_response(self, response_json: dict) -> ChatResult:
 
         for choice in choices:
             message = _convert_dict_to_message(choice["message"])
-            generation_info = dict(finish_reason=choice.get("finish_reason"))
+            generation_info = {"finish_reason": choice.get("finish_reason")}
             if "logprobs" in choice:
                 generation_info["logprobs"] = choice["logprobs"]
 
@@ -794,7 +819,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
     """Number of most likely tokens to consider at each step."""
 
     min_p: Optional[float] = 0.0
-    """Float that represents the minimum probability for a token to be considered.
+    """Float that represents the minimum probability for a token to be considered.
     Must be in [0,1]. 0 to disable this."""
 
     repetition_penalty: Optional[float] = 1.0
@@ -818,7 +843,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
     the EOS token is generated."""
 
     min_tokens: Optional[int] = 0
-    """Minimum number of tokens to generate per output sequence before
+    """Minimum number of tokens to generate per output sequence before
     EOS or stop_token_ids can be generated"""
 
     stop_token_ids: Optional[List[int]] = None
@@ -836,7 +861,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
     tool_choice: Optional[str] = None
     """Whether to use tool calling.
     Defaults to None, tool calling is disabled.
-    Tool calling requires model support and the vLLM to be configured
+    Tool calling requires model support and the vLLM to be configured
     with `--tool-call-parser`.
     Set this to `auto` for the model to make tool calls automatically.
     Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +981,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
     """Total probability mass of tokens to consider at each step."""
 
     top_logprobs: Optional[int] = None
-    """An integer between 0 and 5 specifying the number of most
-    likely tokens to return at each token position, each with an
-    associated log probability. logprobs must be set to true if
+    """An integer between 0 and 5 specifying the number of most
+    likely tokens to return at each token position, each with an
+    associated log probability. logprobs must be set to true if
     this parameter is used."""
 
     @property
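
The last several hunks in this file only strip trailing whitespace from docstrings, which is why the removed and added lines look identical. Functionally, the new `_headers` override injects the default chat route, while `default_headers` (merged in the base class and spread last) wins on key collisions. A sketch of the resulting usage, assuming valid OCI auth is already configured (e.g. via `ads.set_auth()`) and a placeholder endpoint:

from ads.llm.langchain.plugins.chat_models.oci_data_science import (
    ChatOCIModelDeployment,
)

chat = ChatOCIModelDeployment(
    endpoint="https://<MD_OCID>/predict",  # placeholder, not a real endpoint
    model="odsc-llm",
    # Without default_headers, requests carry the built-in
    # DEFAULT_INFERENCE_ENDPOINT_CHAT route ("/v1/chat/completions").
    default_headers={"route": "/v1/chat/completions"},
)

print(chat._headers().get("route"))  # "/v1/chat/completions"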

ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py

Lines changed: 38 additions & 11 deletions
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -24,6 +23,7 @@
 
 import aiohttp
 import requests
+from langchain_community.utilities.requests import Requests
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -34,14 +34,13 @@
 from langchain_core.utils import get_from_dict_or_env
 from pydantic import Field, model_validator
 
-from langchain_community.utilities.requests import Requests
-
 logger = logging.getLogger(__name__)
 
 
 DEFAULT_TIME_OUT = 300
 DEFAULT_CONTENT_TYPE_JSON = "application/json"
 DEFAULT_MODEL_NAME = "odsc-llm"
+DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
 
 
 class TokenExpiredError(Exception):
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""
 
+    default_headers: Optional[Dict[str, Any]] = None
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
@@ -101,7 +103,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 "Please install it with `pip install oracle_ads`."
             ) from ex
 
-        if not values.get("auth", None):
+        if not values.get("auth"):
            values["auth"] = ads.common.auth.default_signer()
 
        values["endpoint"] = get_from_dict_or_env(
@@ -125,12 +127,12 @@ def _headers(
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
+        headers = self.default_headers or {}
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
             req = _req.prepare()
             req = signer(req)
-            headers = {}
             for key, value in req.headers.items():
                 headers[key] = value
 
@@ -140,7 +142,7 @@ def _headers(
             )
             return headers
 
-        return (
+        headers.update(
             {
                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                 "enable-streaming": "true",
@@ -152,6 +154,8 @@ def _headers(
             }
         )
 
+        return headers
+
     def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
@@ -357,7 +361,7 @@ def _refresh_signer(self) -> bool:
             self.auth["signer"].refresh_security_token()
             return True
         return False
-
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         """Return whether this model can be serialized by LangChain."""
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
             model="odsc-llm",
             streaming=True,
             model_kwargs={"frequency_penalty": 1.0},
+            headers={
+                "route": "/v1/completions",
+                # other request headers ...
+            }
         )
         llm.invoke("tell me a joke.")
 
@@ -477,6 +485,25 @@ def _identifying_params(self) -> Dict[str, Any]:
             **self._default_params,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         prompts: List[str],
@@ -712,9 +739,9 @@ def _process_response(self, response_json: dict) -> List[Generation]:
     def _generate_info(self, choice: dict) -> Any:
         """Extracts generation info from the response."""
         gen_info = {}
-        finish_reason = choice.get("finish_reason", None)
-        logprobs = choice.get("logprobs", None)
-        index = choice.get("index", None)
+        finish_reason = choice.get("finish_reason")
+        logprobs = choice.get("logprobs")
+        index = choice.get("index")
         if finish_reason:
             gen_info.update({"finish_reason": finish_reason})
         if logprobs is not None:
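
The base `_headers` now seeds the dict from `default_headers` and applies the standard headers on top, so custom keys such as "route" pass through while standard keys like Content-Type cannot be overridden. (Note the docstring example in this file passes `headers=`, while the field added on `BaseOCIModelDeployment` is `default_headers`.) A standalone sketch of the synchronous merge order, simplified to omit signing and visible only through the hunks above, with an illustrative function name:

from typing import Any, Dict, Optional

DEFAULT_CONTENT_TYPE_JSON = "application/json"


def build_headers(default_headers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    headers = dict(default_headers or {})  # caller-supplied headers first
    headers.update(  # standard headers overwrite on key collisions
        {
            "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
            "enable-streaming": "true",
        }
    )
    return headers


# The custom "route" survives; "Content-Type" is forced back to JSON.
print(build_headers({"route": "/v1/completions", "Content-Type": "text/plain"}))
# {'route': '/v1/completions', 'Content-Type': 'application/json', 'enable-streaming': 'true'}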

tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py

Lines changed: 5 additions & 0 deletions
@@ -26,6 +26,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -123,6 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -135,6 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -149,6 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -187,6 +191,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",

tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -116,6 +117,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -128,6 +130,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -145,6 +148,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -163,6 +167,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
