
Commit 238b6d9

Added headers parameters
1 parent 3d8f148 commit 238b6d9

File tree

2 files changed: +37 -20 lines

  ads/llm/langchain/plugins/chat_models/oci_data_science.py
  ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py

ads/llm/langchain/plugins/chat_models/oci_data_science.py

Lines changed: 18 additions & 9 deletions
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 """Chat model for OCI data science model deployment endpoint."""
 
@@ -50,6 +49,7 @@
5049
)
5150

5251
logger = logging.getLogger(__name__)
52+
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
5353

5454

5555
def _is_pydantic_class(obj: Any) -> bool:
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
+        headers: Optional[Dict]
+            The headers to be added to the Model Deployment request.
 
     Instantiate:
         .. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
                     "temperature": 0.2,
                     # other model parameters ...
                 },
+                headers={
+                    "route": "/v1/chat/completions",
+                    # other request headers ...
+                },
             )
 
     Invocation:
@@ -257,6 +263,9 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
     """Stop words to use when generating. Model output is cut off
     at the first occurrence of any of these substrings."""
 
+    headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_openai(cls, values: Any) -> Any:
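
With this field in place, every ChatOCIModelDeployment instance carries a default {"route": "/v1/chat/completions"} header that can be overridden at construction time. A minimal sketch of overriding the route (the endpoint URL and custom route are placeholders, and this assumes ChatOCIModelDeployment is importable from ads.llm as in recent oracle-ads releases):

    from ads.llm import ChatOCIModelDeployment

    chat = ChatOCIModelDeployment(
        # Placeholder deployment URL; use a real Model Deployment endpoint.
        endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
        model="odsc-llm",
        # Overrides the default {"route": "/v1/chat/completions"} header.
        headers={"route": "/custom/chat/route"},
    )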
@@ -704,7 +713,7 @@ def _process_response(self, response_json: dict) -> ChatResult:
 
         for choice in choices:
             message = _convert_dict_to_message(choice["message"])
-            generation_info = dict(finish_reason=choice.get("finish_reason"))
+            generation_info = {"finish_reason": choice.get("finish_reason")}
             if "logprobs" in choice:
                 generation_info["logprobs"] = choice["logprobs"]
 
@@ -794,7 +803,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
     """Number of most likely tokens to consider at each step."""
 
     min_p: Optional[float] = 0.0
-    """Float that represents the minimum probability for a token to be considered. 
+    """Float that represents the minimum probability for a token to be considered.
     Must be in [0,1]. 0 to disable this."""
 
     repetition_penalty: Optional[float] = 1.0
@@ -818,7 +827,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
     the EOS token is generated."""
 
     min_tokens: Optional[int] = 0
-    """Minimum number of tokens to generate per output sequence before 
+    """Minimum number of tokens to generate per output sequence before
     EOS or stop_token_ids can be generated"""
 
     stop_token_ids: Optional[List[int]] = None
@@ -836,7 +845,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
     tool_choice: Optional[str] = None
     """Whether to use tool calling.
     Defaults to None, tool calling is disabled.
-    Tool calling requires model support and the vLLM to be configured 
+    Tool calling requires model support and the vLLM to be configured
     with `--tool-call-parser`.
     Set this to `auto` for the model to make tool calls automatically.
     Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +965,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
     """Total probability mass of tokens to consider at each step."""
 
     top_logprobs: Optional[int] = None
-    """An integer between 0 and 5 specifying the number of most 
-    likely tokens to return at each token position, each with an 
-    associated log probability. logprobs must be set to true if 
+    """An integer between 0 and 5 specifying the number of most
+    likely tokens to return at each token position, each with an
+    associated log probability. logprobs must be set to true if
     this parameter is used."""
 
     @property

ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py

Lines changed: 19 additions & 11 deletions
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -24,6 +23,7 @@
2423

2524
import aiohttp
2625
import requests
26+
from langchain_community.utilities.requests import Requests
2727
from langchain_core.callbacks import (
2828
AsyncCallbackManagerForLLMRun,
2929
CallbackManagerForLLMRun,
@@ -34,14 +34,13 @@
 from langchain_core.utils import get_from_dict_or_env
 from pydantic import Field, model_validator
 
-from langchain_community.utilities.requests import Requests
-
 logger = logging.getLogger(__name__)
 
 
 DEFAULT_TIME_OUT = 300
 DEFAULT_CONTENT_TYPE_JSON = "application/json"
 DEFAULT_MODEL_NAME = "odsc-llm"
+DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
 
 
 class TokenExpiredError(Exception):
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""
 
+    headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT}
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
@@ -101,7 +103,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 "Please install it with `pip install oracle_ads`."
             ) from ex
 
-        if not values.get("auth", None):
+        if not values.get("auth"):
             values["auth"] = ads.common.auth.default_signer()
 
         values["endpoint"] = get_from_dict_or_env(
@@ -125,12 +127,12 @@ def _headers(
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
+        headers = self.headers
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
             req = _req.prepare()
             req = signer(req)
-            headers = {}
             for key, value in req.headers.items():
                 headers[key] = value
 
@@ -140,7 +142,7 @@ def _headers(
             )
             return headers
 
-        return (
+        headers.update(
             {
                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                 "enable-streaming": "true",
@@ -152,6 +154,8 @@ def _headers(
             }
         )
 
+        return headers
+
     def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
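
Note the merge semantics introduced here: _headers now starts from the user-supplied headers dict and merges the fixed request headers into it via headers.update(...), so a custom "route" (or any extra header) survives alongside "Content-Type" and "enable-streaming". A standalone sketch of that behavior in plain Python, with no OCI dependencies (the copy is added in this sketch to leave the caller's dict untouched, whereas the diff updates self.headers in place):

    DEFAULT_CONTENT_TYPE_JSON = "application/json"

    def merge_headers(user_headers: dict) -> dict:
        # Start from the caller-supplied headers, e.g. {"route": "/v1/completions"},
        # then merge in the fixed headers, mirroring headers.update(...) above.
        headers = dict(user_headers)
        headers.update({
            "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
            "enable-streaming": "true",
        })
        return headers

    print(merge_headers({"route": "/v1/chat/completions"}))
    # {'route': '/v1/chat/completions', 'Content-Type': 'application/json', 'enable-streaming': 'true'}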
@@ -357,7 +361,7 @@ def _refresh_signer(self) -> bool:
             self.auth["signer"].refresh_security_token()
             return True
         return False
- 
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         """Return whether this model can be serialized by LangChain."""
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
                 model="odsc-llm",
                 streaming=True,
                 model_kwargs={"frequency_penalty": 1.0},
+                headers={
+                    "route": "/v1/completions",
+                    # other request headers ...
+                }
             )
             llm.invoke("tell me a joke.")
 
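
Taken together, the two classes now default to different routes: BaseOCIModelDeployment (and thus OCIModelDeploymentLLM) ships {"route": "/v1/completions"}, while ChatOCIModelDeployment overrides it with {"route": "/v1/chat/completions"}. A minimal sketch of the completion LLM with an explicit route (placeholder endpoint URL; assumes OCIModelDeploymentLLM is importable from ads.llm):

    from ads.llm import OCIModelDeploymentLLM

    llm = OCIModelDeploymentLLM(
        # Placeholder deployment URL; use a real Model Deployment endpoint.
        endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
        model="odsc-llm",
        headers={"route": "/v1/completions"},  # same as the default; shown for clarity
    )
    llm.invoke("tell me a joke.")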
@@ -712,9 +720,9 @@ def _process_response(self, response_json: dict) -> List[Generation]:
     def _generate_info(self, choice: dict) -> Any:
         """Extracts generation info from the response."""
         gen_info = {}
-        finish_reason = choice.get("finish_reason", None)
-        logprobs = choice.get("logprobs", None)
-        index = choice.get("index", None)
+        finish_reason = choice.get("finish_reason")
+        logprobs = choice.get("logprobs")
+        index = choice.get("index")
         if finish_reason:
             gen_info.update({"finish_reason": finish_reason})
         if logprobs is not None:
