Skip to content

Commit c5bb94e

Browse files
authored
Updates on LLM integration and documentation (#982)
2 parents c10e309 + d16ca73 commit c5bb94e

File tree

12 files changed

+413
-255
lines changed

12 files changed

+413
-255
lines changed

ads/llm/deploy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@
1919

2020

2121
class ChainDeployment(GenericModel):
22+
"""Represents a model deployment with LangChain.
23+
"""
2224
def __init__(self, chain, **kwargs):
2325
self.chain = chain
26+
if "model_input_serializer" not in kwargs:
27+
kwargs["model_input_serializer"] = self.model_input_serializer_type.JSON
2428
super().__init__(**kwargs)
2529

2630
def prepare(self, **kwargs) -> GenericModel:
2731
"""Prepares the model artifact."""
2832
chain_yaml_uri = os.path.join(self.artifact_dir, "chain.yaml")
33+
if not os.path.exists(self.artifact_dir):
34+
os.makedirs(self.artifact_dir)
2935
with open(chain_yaml_uri, "w", encoding="utf-8") as f:
3036
f.write(yaml.safe_dump(dump(self.chain)))
3137

ads/llm/guardrails/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ class Guardrail(BaseTool):
156156

157157
class Config:
158158
arbitrary_types_allowed = True
159-
underscore_attrs_are_private = True
160159

161160
name: str = ""
162161
description: str = "Guardrail"

ads/llm/langchain/plugins/chat_models/oci_data_science.py

Lines changed: 118 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@
33

44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
"""Chat model for OCI data science model deployment endpoint."""
67

7-
8+
import importlib
89
import json
910
import logging
1011
from operator import itemgetter
1112
from typing import (
1213
Any,
1314
AsyncIterator,
15+
Callable,
1416
Dict,
1517
Iterator,
1618
List,
1719
Literal,
1820
Optional,
21+
Sequence,
1922
Type,
2023
Union,
21-
Sequence,
22-
Callable,
2324
)
2425

2526
from langchain_core.callbacks import (
@@ -33,21 +34,16 @@
3334
generate_from_stream,
3435
)
3536
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
36-
from langchain_core.tools import BaseTool
3737
from langchain_core.output_parsers import (
3838
JsonOutputParser,
3939
PydanticOutputParser,
4040
)
4141
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
4242
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
43+
from langchain_core.tools import BaseTool
4344
from langchain_core.utils.function_calling import convert_to_openai_tool
44-
from langchain_openai.chat_models.base import (
45-
_convert_delta_to_message_chunk,
46-
_convert_message_to_dict,
47-
_convert_dict_to_message,
48-
)
45+
from pydantic import BaseModel, Field, model_validator
4946

50-
from pydantic import BaseModel, Field
5147
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
5248
DEFAULT_MODEL_NAME,
5349
BaseOCIModelDeployment,
@@ -63,23 +59,48 @@ def _is_pydantic_class(obj: Any) -> bool:
6359
class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
6460
"""OCI Data Science Model Deployment chat model integration.
6561
66-
To use, you must provide the model HTTP endpoint from your deployed
67-
chat model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
62+
Setup:
63+
Install ``oracle-ads`` and ``langchain-openai``.
6864
69-
To authenticate, `oracle-ads` has been used to automatically load
70-
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
65+
.. code-block:: bash
7166
72-
Make sure to have the required policies to access the OCI Data
73-
Science Model Deployment endpoint. See:
74-
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
67+
pip install -U oracle-ads langchain-openai
68+
69+
Use `ads.set_auth()` to configure authentication.
70+
For example, to use OCI resource_principal for authentication:
71+
72+
.. code-block:: python
73+
74+
import ads
75+
ads.set_auth("resource_principal")
76+
77+
For more details on authentication, see:
78+
https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
79+
80+
Make sure to have the required policies to access the OCI Data
81+
Science Model Deployment endpoint. See:
82+
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm
83+
84+
85+
Key init args - completion params:
86+
endpoint: str
87+
The OCI model deployment endpoint.
88+
temperature: float
89+
Sampling temperature.
90+
max_tokens: Optional[int]
91+
Max number of tokens to generate.
92+
93+
Key init args — client params:
94+
auth: dict
95+
ADS auth dictionary for OCI authentication.
7596
7697
Instantiate:
7798
.. code-block:: python
7899
79100
from langchain_community.chat_models import ChatOCIModelDeployment
80101
81102
chat = ChatOCIModelDeployment(
82-
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
103+
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
83104
model="odsc-llm",
84105
streaming=True,
85106
max_retries=3,
@@ -94,15 +115,27 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
94115
.. code-block:: python
95116
96117
messages = [
97-
("system", "You are a helpful translator. Translate the user sentence to French."),
118+
("system", "Translate the user sentence to French."),
98119
("human", "Hello World!"),
99120
]
100121
chat.invoke(messages)
101122
102123
.. code-block:: python
103124
104125
AIMessage(
105-
content='Bonjour le monde!',response_metadata={'token_usage': {'prompt_tokens': 40, 'total_tokens': 50, 'completion_tokens': 10},'model_name': 'odsc-llm','system_fingerprint': '','finish_reason': 'stop'},id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0')
126+
content='Bonjour le monde!',
127+
response_metadata={
128+
'token_usage': {
129+
'prompt_tokens': 40,
130+
'total_tokens': 50,
131+
'completion_tokens': 10
132+
},
133+
'model_name': 'odsc-llm',
134+
'system_fingerprint': '',
135+
'finish_reason': 'stop'
136+
},
137+
id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0'
138+
)
106139
107140
Streaming:
108141
.. code-block:: python
@@ -112,18 +145,18 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
112145
113146
.. code-block:: python
114147
115-
content='' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
116-
content='\n' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
117-
content='B' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
118-
content='on' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
119-
content='j' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
120-
content='our' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
121-
content=' le' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
122-
content=' monde' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
123-
content='!' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
124-
content='' response_metadata={'finish_reason': 'stop'} id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
125-
126-
Asyc:
148+
content='' id='run-02c6-c43f-42de'
149+
content='\n' id='run-02c6-c43f-42de'
150+
content='B' id='run-02c6-c43f-42de'
151+
content='on' id='run-02c6-c43f-42de'
152+
content='j' id='run-02c6-c43f-42de'
153+
content='our' id='run-02c6-c43f-42de'
154+
content=' le' id='run-02c6-c43f-42de'
155+
content=' monde' id='run-02c6-c43f-42de'
156+
content='!' id='run-02c6-c43f-42de'
157+
content='' response_metadata={'finish_reason': 'stop'} id='run-02c6-c43f-42de'
158+
159+
Async:
127160
.. code-block:: python
128161
129162
await chat.ainvoke(messages)
@@ -133,7 +166,11 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
133166
134167
.. code-block:: python
135168
136-
AIMessage(content='Bonjour le monde!', response_metadata={'finish_reason': 'stop'}, id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0')
169+
AIMessage(
170+
content='Bonjour le monde!',
171+
response_metadata={'finish_reason': 'stop'},
172+
id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0'
173+
)
137174
138175
Structured output:
139176
.. code-block:: python
@@ -147,19 +184,22 @@ class Joke(BaseModel):
147184
148185
structured_llm = chat.with_structured_output(Joke, method="json_mode")
149186
structured_llm.invoke(
150-
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
187+
"Tell me a joke about cats, "
188+
"respond in JSON with `setup` and `punchline` keys"
151189
)
152190
153191
.. code-block:: python
154192
155-
Joke(setup='Why did the cat get stuck in the tree?',punchline='Because it was chasing its tail!')
193+
Joke(
194+
setup='Why did the cat get stuck in the tree?',
195+
punchline='Because it was chasing its tail!'
196+
)
156197
157198
See ``ChatOCIModelDeployment.with_structured_output()`` for more.
158199
159200
Customized Usage:
160-
161-
You can inherit from base class and overwrite the `_process_response`, `_process_stream_response`,
162-
`_construct_json_body` for satisfying customized needed.
201+
You can inherit from base class and overwrite the `_process_response`,
202+
`_process_stream_response`, `_construct_json_body` for customized usage.
163203
164204
.. code-block:: python
165205
@@ -180,12 +220,31 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
180220
}
181221
182222
chat = MyChatModel(
183-
endpoint=f"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/{ocid}/predict",
223+
endpoint=f"https://modeldeployment.<region>.oci.customer-oci.com/{ocid}/predict",
184224
model="odsc-llm",
185225
}
186226
187227
chat.invoke("tell me a joke")
188228
229+
Response metadata
230+
.. code-block:: python
231+
232+
ai_msg = chat.invoke(messages)
233+
ai_msg.response_metadata
234+
235+
.. code-block:: python
236+
237+
{
238+
'token_usage': {
239+
'prompt_tokens': 40,
240+
'total_tokens': 50,
241+
'completion_tokens': 10
242+
},
243+
'model_name': 'odsc-llm',
244+
'system_fingerprint': '',
245+
'finish_reason': 'stop'
246+
}
247+
189248
""" # noqa: E501
190249

191250
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -198,6 +257,17 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
198257
"""Stop words to use when generating. Model output is cut off
199258
at the first occurrence of any of these substrings."""
200259

260+
@model_validator(mode="before")
261+
@classmethod
262+
def validate_openai(cls, values: Any) -> Any:
263+
"""Checks if langchain_openai is installed."""
264+
if not importlib.util.find_spec("langchain_openai"):
265+
raise ImportError(
266+
"Could not import langchain_openai package. "
267+
"Please install it with `pip install langchain_openai`."
268+
)
269+
return values
270+
201271
@property
202272
def _llm_type(self) -> str:
203273
"""Return type of llm."""
@@ -552,6 +622,8 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
552622
converted messages and additional parameters.
553623
554624
"""
625+
from langchain_openai.chat_models.base import _convert_message_to_dict
626+
555627
return {
556628
"messages": [_convert_message_to_dict(m) for m in messages],
557629
**params,
@@ -578,6 +650,8 @@ def _process_stream_response(
578650
ValueError: If the response JSON is not well-formed or does not
579651
contain the expected structure.
580652
"""
653+
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
654+
581655
try:
582656
choice = response_json["choices"][0]
583657
if not isinstance(choice, dict):
@@ -616,6 +690,8 @@ def _process_response(self, response_json: dict) -> ChatResult:
616690
contain the expected structure.
617691
618692
"""
693+
from langchain_openai.chat_models.base import _convert_dict_to_message
694+
619695
generations = []
620696
try:
621697
choices = response_json["choices"]
@@ -760,8 +836,9 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
760836
tool_choice: Optional[str] = None
761837
"""Whether to use tool calling.
762838
Defaults to None, tool calling is disabled.
763-
Tool calling requires model support and vLLM to be configured with `--tool-call-parser`.
764-
Set this to `auto` for the model to determine whether to make tool calls automatically.
839+
Tool calling requires model support and the vLLM to be configured
840+
with `--tool-call-parser`.
841+
Set this to `auto` for the model to make tool calls automatically.
765842
Set this to `required` to force the model to always call one or more tools.
766843
"""
767844

0 commit comments

Comments (0)