1
1
#!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
"""Chat model for OCI data science model deployment endpoint."""
7
6
50
49
)
51
50
52
51
logger = logging .getLogger (__name__ )
52
+ DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
53
53
54
54
55
55
def _is_pydantic_class (obj : Any ) -> bool :
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
93
93
Key init args — client params:
94
94
auth: dict
95
95
ADS auth dictionary for OCI authentication.
96
+ headers: Optional[Dict]
97
+ The headers to be added to the Model Deployment request.
96
98
97
99
Instantiate:
98
100
.. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
109
111
"temperature": 0.2,
110
112
# other model parameters ...
111
113
},
114
+ headers={
115
+ "route": "/v1/chat/completions",
116
+ # other request headers ...
117
+ },
112
118
)
113
119
114
120
Invocation:
@@ -257,6 +263,9 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
257
263
"""Stop words to use when generating. Model output is cut off
258
264
at the first occurrence of any of these substrings."""
259
265
266
+ headers : Optional [Dict [str , Any ]] = {"route" : DEFAULT_INFERENCE_ENDPOINT_CHAT }
267
+ """The headers to be added to the Model Deployment request."""
268
+
260
269
@model_validator (mode = "before" )
261
270
@classmethod
262
271
def validate_openai (cls , values : Any ) -> Any :
@@ -704,7 +713,7 @@ def _process_response(self, response_json: dict) -> ChatResult:
704
713
705
714
for choice in choices :
706
715
message = _convert_dict_to_message (choice ["message" ])
707
- generation_info = dict ( finish_reason = choice .get ("finish_reason" ))
716
+ generation_info = { " finish_reason" : choice .get ("finish_reason" )}
708
717
if "logprobs" in choice :
709
718
generation_info ["logprobs" ] = choice ["logprobs" ]
710
719
@@ -794,7 +803,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
794
803
"""Number of most likely tokens to consider at each step."""
795
804
796
805
min_p : Optional [float ] = 0.0
797
- """Float that represents the minimum probability for a token to be considered.
806
+ """Float that represents the minimum probability for a token to be considered.
798
807
Must be in [0,1]. 0 to disable this."""
799
808
800
809
repetition_penalty : Optional [float ] = 1.0
@@ -818,7 +827,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
818
827
the EOS token is generated."""
819
828
820
829
min_tokens : Optional [int ] = 0
821
- """Minimum number of tokens to generate per output sequence before
830
+ """Minimum number of tokens to generate per output sequence before
822
831
EOS or stop_token_ids can be generated"""
823
832
824
833
stop_token_ids : Optional [List [int ]] = None
@@ -836,7 +845,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
836
845
tool_choice : Optional [str ] = None
837
846
"""Whether to use tool calling.
838
847
Defaults to None, tool calling is disabled.
839
- Tool calling requires model support and the vLLM to be configured
848
+ Tool calling requires model support and the vLLM to be configured
840
849
with `--tool-call-parser`.
841
850
Set this to `auto` for the model to make tool calls automatically.
842
851
Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +965,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
956
965
"""Total probability mass of tokens to consider at each step."""
957
966
958
967
top_logprobs : Optional [int ] = None
959
- """An integer between 0 and 5 specifying the number of most
960
- likely tokens to return at each token position, each with an
961
- associated log probability. logprobs must be set to true if
968
+ """An integer between 0 and 5 specifying the number of most
969
+ likely tokens to return at each token position, each with an
970
+ associated log probability. logprobs must be set to true if
962
971
this parameter is used."""
963
972
964
973
@property
0 commit comments