3
3
4
4
# Copyright (c) 2023 Oracle and/or its affiliates.
5
5
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+ """Chat model for OCI data science model deployment endpoint."""
6
7
7
-
8
+ import importlib
8
9
import json
9
10
import logging
10
11
from operator import itemgetter
11
12
from typing import (
12
13
Any ,
13
14
AsyncIterator ,
15
+ Callable ,
14
16
Dict ,
15
17
Iterator ,
16
18
List ,
17
19
Literal ,
18
20
Optional ,
21
+ Sequence ,
19
22
Type ,
20
23
Union ,
21
- Sequence ,
22
- Callable ,
23
24
)
24
25
25
26
from langchain_core .callbacks import (
33
34
generate_from_stream ,
34
35
)
35
36
from langchain_core .messages import AIMessageChunk , BaseMessage , BaseMessageChunk
36
- from langchain_core .tools import BaseTool
37
37
from langchain_core .output_parsers import (
38
38
JsonOutputParser ,
39
39
PydanticOutputParser ,
40
40
)
41
41
from langchain_core .outputs import ChatGeneration , ChatGenerationChunk , ChatResult
42
42
from langchain_core .runnables import Runnable , RunnableMap , RunnablePassthrough
43
+ from langchain_core .tools import BaseTool
43
44
from langchain_core .utils .function_calling import convert_to_openai_tool
44
- from langchain_openai .chat_models .base import (
45
- _convert_delta_to_message_chunk ,
46
- _convert_message_to_dict ,
47
- _convert_dict_to_message ,
48
- )
45
+ from pydantic import BaseModel , Field , model_validator
49
46
50
- from pydantic import BaseModel , Field
51
47
from ads .llm .langchain .plugins .llms .oci_data_science_model_deployment_endpoint import (
52
48
DEFAULT_MODEL_NAME ,
53
49
BaseOCIModelDeployment ,
@@ -63,23 +59,48 @@ def _is_pydantic_class(obj: Any) -> bool:
63
59
class ChatOCIModelDeployment (BaseChatModel , BaseOCIModelDeployment ):
64
60
"""OCI Data Science Model Deployment chat model integration.
65
61
66
- To use, you must provide the model HTTP endpoint from your deployed
67
- chat model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict .
62
+ Setup:
63
+ Install ``oracle-ads`` and ``langchain-openai`` .
68
64
69
- To authenticate, `oracle-ads` has been used to automatically load
70
- credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
65
+ .. code-block:: bash
71
66
72
- Make sure to have the required policies to access the OCI Data
73
- Science Model Deployment endpoint. See:
74
- https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
67
+ pip install -U oracle-ads langchain-openai
68
+
69
+ Use `ads.set_auth()` to configure authentication.
70
+ For example, to use OCI resource_principal for authentication:
71
+
72
+ .. code-block:: python
73
+
74
+ import ads
75
+ ads.set_auth("resource_principal")
76
+
77
+ For more details on authentication, see:
78
+ https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
79
+
80
+ Make sure to have the required policies to access the OCI Data
81
+ Science Model Deployment endpoint. See:
82
+ https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm
83
+
84
+
85
+ Key init args - completion params:
86
+ endpoint: str
87
+ The OCI model deployment endpoint.
88
+ temperature: float
89
+ Sampling temperature.
90
+ max_tokens: Optional[int]
91
+ Max number of tokens to generate.
92
+
93
+ Key init args — client params:
94
+ auth: dict
95
+ ADS auth dictionary for OCI authentication.
75
96
76
97
Instantiate:
77
98
.. code-block:: python
78
99
79
100
from langchain_community.chat_models import ChatOCIModelDeployment
80
101
81
102
chat = ChatOCIModelDeployment(
82
- endpoint="https://modeldeployment.us-ashburn-1 .oci.customer-oci.com/<ocid>/predict",
103
+ endpoint="https://modeldeployment.<region> .oci.customer-oci.com/<ocid>/predict",
83
104
model="odsc-llm",
84
105
streaming=True,
85
106
max_retries=3,
@@ -94,15 +115,27 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
94
115
.. code-block:: python
95
116
96
117
messages = [
97
- ("system", "You are a helpful translator. Translate the user sentence to French."),
118
+ ("system", "Translate the user sentence to French."),
98
119
("human", "Hello World!"),
99
120
]
100
121
chat.invoke(messages)
101
122
102
123
.. code-block:: python
103
124
104
125
AIMessage(
105
- content='Bonjour le monde!',response_metadata={'token_usage': {'prompt_tokens': 40, 'total_tokens': 50, 'completion_tokens': 10},'model_name': 'odsc-llm','system_fingerprint': '','finish_reason': 'stop'},id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0')
126
+ content='Bonjour le monde!',
127
+ response_metadata={
128
+ 'token_usage': {
129
+ 'prompt_tokens': 40,
130
+ 'total_tokens': 50,
131
+ 'completion_tokens': 10
132
+ },
133
+ 'model_name': 'odsc-llm',
134
+ 'system_fingerprint': '',
135
+ 'finish_reason': 'stop'
136
+ },
137
+ id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0'
138
+ )
106
139
107
140
Streaming:
108
141
.. code-block:: python
@@ -112,18 +145,18 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
112
145
113
146
.. code-block:: python
114
147
115
- content='' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
116
- content='\n ' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
117
- content='B' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
118
- content='on' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
119
- content='j' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
120
- content='our' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
121
- content=' le' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
122
- content=' monde' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
123
- content='!' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
124
- content='' response_metadata={'finish_reason': 'stop'} id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
125
-
126
- Asyc :
148
+ content='' id='run-02c6 -c43f-42de'
149
+ content='\n ' id='run-02c6 -c43f-42de'
150
+ content='B' id='run-02c6 -c43f-42de'
151
+ content='on' id='run-02c6 -c43f-42de'
152
+ content='j' id='run-02c6 -c43f-42de'
153
+ content='our' id='run-02c6 -c43f-42de'
154
+ content=' le' id='run-02c6 -c43f-42de'
155
+ content=' monde' id='run-02c6 -c43f-42de'
156
+ content='!' id='run-02c6 -c43f-42de'
157
+ content='' response_metadata={'finish_reason': 'stop'} id='run-02c6 -c43f-42de'
158
+
159
+ Async :
127
160
.. code-block:: python
128
161
129
162
await chat.ainvoke(messages)
@@ -133,7 +166,11 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
133
166
134
167
.. code-block:: python
135
168
136
- AIMessage(content='Bonjour le monde!', response_metadata={'finish_reason': 'stop'}, id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0')
169
+ AIMessage(
170
+ content='Bonjour le monde!',
171
+ response_metadata={'finish_reason': 'stop'},
172
+ id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0'
173
+ )
137
174
138
175
Structured output:
139
176
.. code-block:: python
@@ -147,19 +184,22 @@ class Joke(BaseModel):
147
184
148
185
structured_llm = chat.with_structured_output(Joke, method="json_mode")
149
186
structured_llm.invoke(
150
- "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
187
+ "Tell me a joke about cats, "
188
+ "respond in JSON with `setup` and `punchline` keys"
151
189
)
152
190
153
191
.. code-block:: python
154
192
155
- Joke(setup='Why did the cat get stuck in the tree?',punchline='Because it was chasing its tail!')
193
+ Joke(
194
+ setup='Why did the cat get stuck in the tree?',
195
+ punchline='Because it was chasing its tail!'
196
+ )
156
197
157
198
See ``ChatOCIModelDeployment.with_structured_output()`` for more.
158
199
159
200
Customized Usage:
160
-
161
- You can inherit from base class and overwrite the `_process_response`, `_process_stream_response`,
162
- `_construct_json_body` for satisfying customized needed.
201
+ You can inherit from base class and overwrite the `_process_response`,
202
+ `_process_stream_response`, `_construct_json_body` for customized usage.
163
203
164
204
.. code-block:: python
165
205
@@ -180,12 +220,31 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
180
220
}
181
221
182
222
chat = MyChatModel(
183
- endpoint=f"https://modeldeployment.us-ashburn-1 .oci.customer-oci.com/{ocid}/predict",
223
+ endpoint=f"https://modeldeployment.<region> .oci.customer-oci.com/{ocid}/predict",
184
224
model="odsc-llm",
185
225
}
186
226
187
227
chat.invoke("tell me a joke")
188
228
229
+ Response metadata:
230
+ .. code-block:: python
231
+
232
+ ai_msg = chat.invoke(messages)
233
+ ai_msg.response_metadata
234
+
235
+ .. code-block:: python
236
+
237
+ {
238
+ 'token_usage': {
239
+ 'prompt_tokens': 40,
240
+ 'total_tokens': 50,
241
+ 'completion_tokens': 10
242
+ },
243
+ 'model_name': 'odsc-llm',
244
+ 'system_fingerprint': '',
245
+ 'finish_reason': 'stop'
246
+ }
247
+
189
248
""" # noqa: E501
190
249
191
250
model_kwargs : Dict [str , Any ] = Field (default_factory = dict )
@@ -198,6 +257,17 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
198
257
"""Stop words to use when generating. Model output is cut off
199
258
at the first occurrence of any of these substrings."""
200
259
260
+ @model_validator (mode = "before" )
261
+ @classmethod
262
+ def validate_openai (cls , values : Any ) -> Any :
263
+ """Checks if langchain_openai is installed."""
264
+ if not importlib .util .find_spec ("langchain_openai" ):
265
+ raise ImportError (
266
+ "Could not import langchain_openai package. "
267
+ "Please install it with `pip install langchain_openai`."
268
+ )
269
+ return values
270
+
201
271
@property
202
272
def _llm_type (self ) -> str :
203
273
"""Return type of llm."""
@@ -552,6 +622,8 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
552
622
converted messages and additional parameters.
553
623
554
624
"""
625
+ from langchain_openai .chat_models .base import _convert_message_to_dict
626
+
555
627
return {
556
628
"messages" : [_convert_message_to_dict (m ) for m in messages ],
557
629
** params ,
@@ -578,6 +650,8 @@ def _process_stream_response(
578
650
ValueError: If the response JSON is not well-formed or does not
579
651
contain the expected structure.
580
652
"""
653
+ from langchain_openai .chat_models .base import _convert_delta_to_message_chunk
654
+
581
655
try :
582
656
choice = response_json ["choices" ][0 ]
583
657
if not isinstance (choice , dict ):
@@ -616,6 +690,8 @@ def _process_response(self, response_json: dict) -> ChatResult:
616
690
contain the expected structure.
617
691
618
692
"""
693
+ from langchain_openai .chat_models .base import _convert_dict_to_message
694
+
619
695
generations = []
620
696
try :
621
697
choices = response_json ["choices" ]
@@ -760,8 +836,9 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
760
836
tool_choice : Optional [str ] = None
761
837
"""Whether to use tool calling.
762
838
Defaults to None, tool calling is disabled.
763
- Tool calling requires model support and vLLM to be configured with `--tool-call-parser`.
764
- Set this to `auto` for the model to determine whether to make tool calls automatically.
839
+ Tool calling requires model support and vLLM to be configured
840
+ with `--tool-call-parser`.
841
+ Set this to `auto` for the model to make tool calls automatically.
765
842
Set this to `required` to force the model to always call one or more tools.
766
843
"""
767
844
0 commit comments