
Commit d1b0ede

Update Generative AI LLM model to use new APIs. (#523)

2 parents 9968a51 + b1f1f42

7 files changed: +135 −72 lines changed

ads/llm/langchain/plugins/base.py
Lines changed: 16 additions & 11 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 from typing import Any, Dict, List, Optional
 
@@ -77,27 +77,32 @@ class GenerativeAiClientModel(BaseModel):
     client_kwargs: Dict[str, Any] = {}
     """Holds any client parameters for creating GenerativeAiClient"""
 
-    @root_validator()
-    def validate_environment(  # pylint: disable=no-self-argument
-        cls, values: Dict
-    ) -> Dict:
-        """Validate that python package exists in environment."""
+    @staticmethod
+    def _import_client():
         try:
-            # Import the GenerativeAIClient here so that there will be no error when user import ads.llm
-            # and the install OCI SDK does not support generative AI service yet.
-            from oci.generative_ai import GenerativeAiClient
+            from oci.generative_ai_inference import GenerativeAiInferenceClient
         except ImportError as ex:
             raise ImportError(
-                "Could not import GenerativeAIClient from oci. "
+                "Could not import GenerativeAiInferenceClient from oci. "
                 "The OCI SDK installed does not support generative AI service."
             ) from ex
+        return GenerativeAiInferenceClient
+
+    @root_validator()
+    def validate_environment(  # pylint: disable=no-self-argument
+        cls, values: Dict
+    ) -> Dict:
+        """Validate that python package exists in environment."""
         # Initialize client only if user does not pass in client.
         # Users may choose to initialize the OCI client by themselves and pass it into this model.
         if not values.get("client"):
             auth = values.get("auth", {})
             client_kwargs = auth.get("client_kwargs") or {}
             client_kwargs.update(values["client_kwargs"])
-            values["client"] = GenerativeAiClient(**auth, **client_kwargs)
+            # Import the GenerativeAIClient here so that there will be no error when user import ads.llm
+            # and the install OCI SDK does not support generative AI service yet.
+            client_class = cls._import_client()
+            values["client"] = client_class(**auth, **client_kwargs)
         # Set default compartment ID
         if not values.get("compartment_id"):
             if COMPARTMENT_OCID:
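
Note: moving the import into the _import_client() staticmethod keeps "import ads.llm" working even when the installed OCI SDK predates the Generative AI inference service, because the client class is only resolved when a client is actually constructed. A minimal standalone sketch of the same pattern, assuming only the oci SDK; the make_client helper and its arguments are hypothetical, for illustration:

from typing import Any, Dict


def _import_client():
    try:
        # Imported lazily so that importing this module never fails,
        # even if the installed OCI SDK lacks the inference service.
        from oci.generative_ai_inference import GenerativeAiInferenceClient
    except ImportError as ex:
        raise ImportError(
            "Could not import GenerativeAiInferenceClient from oci. "
            "The OCI SDK installed does not support generative AI service."
        ) from ex
    return GenerativeAiInferenceClient


def make_client(auth: Dict[str, Any], client_kwargs: Dict[str, Any]):
    # Resolve the class only at construction time, never at module import.
    client_class = _import_client()
    return client_class(**auth, **client_kwargs)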

ads/llm/langchain/plugins/contant.py
Lines changed: 2 additions & 7 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 from enum import Enum
 
@@ -21,7 +21,7 @@ class StrEnum(str, Enum):
 
 class Task(StrEnum):
     TEXT_GENERATION = "text_generation"
-    SUMMARY_TEXT = "summary_text"
+    TEXT_SUMMARIZATION = "text_summarization"
 
 
 class LengthParam(StrEnum):
@@ -42,8 +42,3 @@ class ExtractivenessParam(StrEnum):
     MEDIUM = "MEDIUM"
     HIGH = "HIGH"
     AUTO = "AUTO"
-
-
-class OCIGenerativeAIModel(StrEnum):
-    COHERE_COMMAND = "cohere.command"
-    COHERE_COMMAND_LIGHT = "cohere.command-light"
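Note: because StrEnum subclasses str, enum members compare equal to their plain-string values, which is why the str-typed task field on the model can be checked against Task.TEXT_SUMMARIZATION directly. A self-contained sketch of the mechanism:

from enum import Enum


class StrEnum(str, Enum):
    """Enum whose members are also strings."""


class Task(StrEnum):
    TEXT_GENERATION = "text_generation"
    TEXT_SUMMARIZATION = "text_summarization"


# Members compare equal to plain strings, so a field declared as
# task: str = "text_generation" can be tested against the enum.
assert Task.TEXT_SUMMARIZATION == "text_summarization"
assert "text_generation" == Task.TEXT_GENERATION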

ads/llm/langchain/plugins/embeddings.py
Lines changed: 5 additions & 2 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 from typing import List, Optional
@@ -38,7 +38,10 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         Returns:
             List of embeddings, one for each text.
         """
-        from oci.generative_ai.models import EmbedTextDetails, OnDemandServingMode
+        from oci.generative_ai_inference.models import (
+            EmbedTextDetails,
+            OnDemandServingMode,
+        )
 
         details = EmbedTextDetails(
             compartment_id=self.compartment_id,
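
Note: for reference, a sketch of exercising the renamed imports directly against the inference client, outside the LangChain wrapper. The compartment OCID and model ID are placeholders, and the EmbedTextDetails field names follow current oci SDK releases; treat the details as assumptions:

# Sketch only: calls the inference API directly with default ~/.oci/config auth.
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import EmbedTextDetails, OnDemandServingMode

config = oci.config.from_file()
client = GenerativeAiInferenceClient(
    config,
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
)
details = EmbedTextDetails(
    compartment_id="<compartment_ocid>",  # placeholder
    serving_mode=OnDemandServingMode(model_id="cohere.embed-english-v3.0"),  # example ID
    inputs=["Hello world!"],
)
response = client.embed_text(details)
print(response.data.embeddings[0][:5])  # first few floats of the first embedding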

ads/llm/langchain/plugins/llm_gen_ai.py
Lines changed: 88 additions & 36 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import logging
@@ -10,7 +10,7 @@
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 
 from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
-from ads.llm.langchain.plugins.contant import *
+from ads.llm.langchain.plugins.contant import Task
 
 logger = logging.getLogger(__name__)
 
@@ -32,7 +32,7 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
     """
 
     task: str = "text_generation"
-    """Indicates the task."""
+    """Task can be either text_generation or text_summarization."""
 
     model: Optional[str] = "cohere.command"
     """Model name to use."""
@@ -106,7 +106,7 @@ def _default_params(self) -> Dict[str, Any]:
 
     def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
         params = self._default_params
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
             return {**params}
 
         if self.stop is not None and stop is not None:
@@ -149,11 +149,7 @@ def _call(
         self._print_request(prompt, params)
 
         try:
-            response = (
-                self.completion_with_retry(prompts=[prompt], **params)
-                if self.task == Task.TEXT_GENERATION
-                else self.completion_with_retry(input=prompt, **params)
-            )
+            completion = self.completion_with_retry(prompt=prompt, **params)
         except Exception:
             logger.error(
                 "Error occur when invoking oci service api."
@@ -164,39 +160,95 @@ def _call(
             )
             raise
 
-        completion = self._process_response(response, params.get("num_generations", 1))
-        self._print_response(completion, response)
         return completion
 
-    def _process_response(self, response: Any, num_generations: int = 1) -> str:
-        if self.task == Task.SUMMARY_TEXT:
-            return response.data.summary
+    def _text_generation(self, request_class, serving_mode, **kwargs):
+        from oci.generative_ai_inference.models import (
+            GenerateTextDetails,
+            GenerateTextResult,
+        )
 
-        return (
-            response.data.generated_texts[0][0].text
-            if num_generations == 1
-            else [gen.text for gen in response.data.generated_texts[0]]
+        compartment_id = kwargs.pop("compartment_id")
+        inference_request = request_class(**kwargs)
+        response = self.client.generate_text(
+            GenerateTextDetails(
+                compartment_id=compartment_id,
+                serving_mode=serving_mode,
+                inference_request=inference_request,
+            ),
+            **self.endpoint_kwargs,
+        ).data
+        response: GenerateTextResult
+        return response.inference_response
+
+    def _cohere_completion(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import (
+            CohereLlmInferenceRequest,
+            CohereLlmInferenceResponse,
         )
 
-    def completion_with_retry(self, **kwargs: Any) -> Any:
-        from oci.generative_ai.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
+        response = self._text_generation(
+            CohereLlmInferenceRequest, serving_mode, **kwargs
         )
+        response: CohereLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.generated_texts[0].text
+        else:
+            completion = [result.text for result in response.generated_texts]
+        self._print_response(completion, response)
+        return completion
 
-        # TODO: Add retry logic for OCI
-        # Convert the ``model`` parameter to OCI ``ServingMode``
-        # Note that "ServingMode` is not JSON serializable.
-        kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            return self.client.generate_text(
-                GenerateTextDetails(**kwargs), **self.endpoint_kwargs
-            )
+    def _llama_completion(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import (
+            LlamaLlmInferenceRequest,
+            LlamaLlmInferenceResponse,
+        )
+
+        # truncate and stop_sequence are not supported.
+        kwargs.pop("truncate", None)
+        kwargs.pop("stop_sequences", None)
+        # top_k must be >1 or -1
+        if "top_k" in kwargs and kwargs["top_k"] == 0:
+            kwargs.pop("top_k")
+
+        # top_p must be 1 when temperature is 0
+        if kwargs.get("temperature") == 0:
+            kwargs["top_p"] = 1
+
+        response = self._text_generation(
+            LlamaLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: LlamaLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.choices[0].text
         else:
-            return self.client.summarize_text(
-                SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
-            )
+            completion = [result.text for result in response.choices]
+        self._print_response(completion, response)
+        return completion
+
+    def _cohere_summarize(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import SummarizeTextDetails
+
+        kwargs["input"] = kwargs.pop("prompt")
+
+        response = self.client.summarize_text(
+            SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
+            **self.endpoint_kwargs,
+        )
+        return response.data.summary
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        from oci.generative_ai_inference.models import OnDemandServingMode
+
+        serving_mode = OnDemandServingMode(model_id=self.model)
+
+        if self.task == Task.TEXT_SUMMARIZATION:
+            return self._cohere_summarize(serving_mode, **kwargs)
+        elif self.model.startswith("cohere"):
+            return self._cohere_completion(serving_mode, **kwargs)
+        elif self.model.startswith("meta.llama"):
+            return self._llama_completion(serving_mode, **kwargs)
+        raise ValueError(f"Model {self.model} is not supported.")
 
     def batch_completion(
         self,
@@ -235,9 +287,9 @@ def batch_completion(
         responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)
 
         """
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
             raise NotImplementedError(
-                f"task={Task.SUMMARY_TEXT} does not support batch_completion. "
+                f"task={Task.TEXT_SUMMARIZATION} does not support batch_completion. "
             )
 
         return self._call(
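
Note: completion_with_retry now dispatches on task and model-name prefix rather than task alone: text_summarization goes to _cohere_summarize, "cohere*" models to _cohere_completion, "meta.llama*" models to _llama_completion, and anything else raises ValueError. A usage sketch of the resulting behavior; the model ID and prompt are examples, and .invoke assumes a LangChain version exposing the runnable interface:

# Usage sketch: the same GenerativeAI class serves Cohere and Llama models.
from ads.llm import GenerativeAI

llm = GenerativeAI(
    compartment_id="<compartment_ocid>",  # placeholder
    model="meta.llama-2-70b-chat",        # example ID; routed to _llama_completion
    client_kwargs={
        "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
    },
)
# For Llama models, truncate/stop_sequences are dropped and top_k=0 is
# removed before the LlamaLlmInferenceRequest is built, per _llama_completion.
print(llm.invoke("Tell me a joke."))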

docs/source/user_guide/large_language_model/index.rst
Lines changed: 3 additions & 0 deletions

@@ -19,6 +19,9 @@ Integration with LangChain
 **************************
 ADS is designed to work with LangChain, enabling developers to incorporate various LangChain components and models deployed on OCI seamlessly into their applications. Additionally, ADS can package LangChain applications and deploy it as a REST API endpoint using OCI Data Science Model Deployment.
 
+* `Bridging cloud and conversational AI: LangChain and OCI Data Science platform <https://blogs.oracle.com/ai-and-datascience/post/cloud-conversational-ai-langchain-oci-data-science>`_
+* `Deploy LangChain applications as OCI model deployments <https://blogs.oracle.com/ai-and-datascience/post/deploy-langchain-application-as-model-deployment>`_
+
 
 .. admonition:: Installation
    :class: note

docs/source/user_guide/large_language_model/langchain_models.rst
Lines changed: 14 additions & 9 deletions

@@ -26,7 +26,7 @@ To use the text generation model as LLM in LangChain:
         compartment_id="<compartment_ocid>",
         # Optionally you can specify keyword arguments for the OCI client, e.g. service_endpoint.
         client_kwargs={
-            "service_endpoint": "https://generativeai.aiservice.us-chicago-1.oci.oraclecloud.com"
+            "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
         },
     )
 
@@ -44,22 +44,22 @@ Here is an example of using prompt template and OCI generative AI LLM to build a
     map_input = RunnableParallel(text=RunnablePassthrough())
     # Template for the input text.
     template = PromptTemplate.from_template(
-        "Translate the text into French.\nText:{text}\nFrench translation: "
+        "Translate English into French. Do not ask any questions.\nEnglish: Hello!\nFrench: "
     )
     llm = GenerativeAI(
         compartment_id="<compartment_ocid>",
         # Optionally you can specify keyword arguments for the OCI client, e.g. service_endpoint.
         client_kwargs={
-            "service_endpoint": "https://generativeai.aiservice.us-chicago-1.oci.oraclecloud.com"
+            "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
         },
     )
 
     # Build the app as a chain
     translation_app = map_input | template | llm
 
     # Now you have a translation app.
-    translation_app.invoke("How are you?")
-    # "Comment ça va?"
+    translation_app.invoke("Hello!")
+    # "Bonjour!"
 
 Similarly, you can use the embedding model:
 
@@ -71,7 +71,7 @@ Similarly, you can use the embedding model:
         compartment_id="<compartment_ocid>",
         # Optionally you can specify keyword arguments for the OCI client, e.g. service_endpoint.
         client_kwargs={
-            "service_endpoint": "https://generativeai.aiservice.us-chicago-1.oci.oraclecloud.com"
+            "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
         },
     )
 
@@ -80,6 +80,11 @@ Similarly, you can use the embedding model:
 Integration with Model Deployment
 =================================
 
+.. admonition:: Available in LangChain
+   :class: note
+
+   The same ``OCIModelDeploymentVLLM`` and ``ModelDeploymentTGI`` classes are also `available from LangChain <https://python.langchain.com/docs/integrations/llms/oci_model_deployment_endpoint>`_.
+
 If you deploy open-source or your own LLM on OCI model deployment service using `vLLM <https://docs.vllm.ai/en/latest/>`_ or `HuggingFace TGI <https://huggingface.co/docs/text-generation-inference/index>`_ , you can use the ``ModelDeploymentVLLM`` or ``ModelDeploymentTGI`` to integrate your model with LangChain.
 
 .. code-block:: python3
@@ -115,7 +120,7 @@ By default, the integration uses the same authentication method configured with
         compartment_id="<compartment_ocid>",
         # Optionally you can specify keyword arguments for the OCI client, e.g. service_endpoint.
         client_kwargs={
-            "service_endpoint": "https://generativeai.aiservice.us-chicago-1.oci.oraclecloud.com"
+            "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
         },
     )
 
@@ -132,6 +137,6 @@ Alternatively, you may use specific authentication for the model:
         compartment_id="<compartment_ocid>",
         # Optionally you can specify keyword arguments for the OCI client, e.g. service_endpoint.
         client_kwargs={
-            "service_endpoint": "https://generativeai.aiservice.us-chicago-1.oci.oraclecloud.com"
+            "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
         },
-    )
+    )
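
Note: the documentation above only covers text generation and embeddings. With the task renamed in contant.py, a summarization call through the same class would look roughly like this; a sketch, assuming the default cohere.command model supports the summarize endpoint and that .invoke is available:

from ads.llm import GenerativeAI

summarizer = GenerativeAI(
    compartment_id="<compartment_ocid>",  # placeholder
    task="text_summarization",            # renamed from summary_text in this commit
    client_kwargs={
        "service_endpoint": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
    },
)
# Routed to _cohere_summarize; the prompt becomes the "input" of
# SummarizeTextDetails and the summary string is returned.
summary = summarizer.invoke("<a long passage of text to condense>")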
