Commit f5905d8

Update OCI Generative AI LLM model to use new APIs.
1 parent fe3aac6

File tree

4 files changed: +87 -121 lines

ads/llm/langchain/plugins/base.py
ads/llm/langchain/plugins/contant.py
ads/llm/langchain/plugins/embeddings.py
ads/llm/langchain/plugins/llm_gen_ai.py

ads/llm/langchain/plugins/base.py

Lines changed: 6 additions & 26 deletions
@@ -77,36 +77,16 @@ class GenerativeAiClientModel(BaseModel):
     client_kwargs: Dict[str, Any] = {}
     """Holds any client parameters for creating GenerativeAiClient"""
 
-    @staticmethod
-    def _import_client_v1():
-        from oci.generative_ai import GenerativeAiClient
-
-        return GenerativeAiClient
-
-    @staticmethod
-    def _import_client_v2():
-        from oci.generative_ai_inference import GenerativeAiInferenceClient
-
-        return GenerativeAiInferenceClient
-
     @staticmethod
     def _import_client():
-        import_methods = [
-            GenerativeAiClientModel._import_client_v1,
-            GenerativeAiClientModel._import_client_v2,
-        ]
-        client_class = None
-        for import_client_method in import_methods:
-            try:
-                client_class = import_client_method()
-            except ImportError:
-                pass
-        if not client_class:
+        try:
+            from oci.generative_ai_inference import GenerativeAiInferenceClient
+        except ImportError as ex:
             raise ImportError(
-                "Could not import GenerativeAIClient or GenerativeAiInferenceClient from oci. "
+                "Could not import GenerativeAiInferenceClient from oci. "
                 "The OCI SDK installed does not support generative AI service."
-            )
-        return client_class
+            ) from ex
+        return GenerativeAiInferenceClient
 
     @root_validator()
     def validate_environment(  # pylint: disable=no-self-argument
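For reference, a minimal sketch of how the class returned by the new _import_client could be used; the config file lookup and the service endpoint below are illustrative assumptions, not part of this commit:

    import oci
    from ads.llm.langchain.plugins.base import GenerativeAiClientModel

    # Resolve the inference client class; this raises ImportError when the
    # installed OCI SDK has no generative AI inference support.
    client_cls = GenerativeAiClientModel._import_client()

    # Assumed construction: OCI service clients accept a config dict and an
    # optional service_endpoint override (placeholder region shown).
    client = client_cls(
        config=oci.config.from_file(),
        service_endpoint="https://inference.generativeai.<region>.oci.oraclecloud.com",
    )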

ads/llm/langchain/plugins/contant.py

Lines changed: 2 additions & 7 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 from enum import Enum
 
@@ -21,7 +21,7 @@ class StrEnum(str, Enum):
 
 class Task(StrEnum):
     TEXT_GENERATION = "text_generation"
-    SUMMARY_TEXT = "summary_text"
+    TEXT_SUMMARIZATION = "text_summarization"
 
 
 class LengthParam(StrEnum):
@@ -42,8 +42,3 @@ class ExtractivenessParam(StrEnum):
     MEDIUM = "MEDIUM"
     HIGH = "HIGH"
     AUTO = "AUTO"
-
-
-class OCIGenerativeAIModel(StrEnum):
-    COHERE_COMMAND = "cohere.command"
-    COHERE_COMMAND_LIGHT = "cohere.command-light"
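Since Task derives from StrEnum (str plus Enum), the renamed member still compares equal to its plain-string value, which is what the plugin's string-typed task field relies on; a small illustration:

    from ads.llm.langchain.plugins.contant import Task

    # StrEnum members behave as plain strings, so either form can be used
    # when setting the plugin's `task` field.
    assert Task.TEXT_GENERATION == "text_generation"
    assert Task.TEXT_SUMMARIZATION == "text_summarization"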

ads/llm/langchain/plugins/embeddings.py

Lines changed: 4 additions & 1 deletion
@@ -38,7 +38,10 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         Returns:
             List of embeddings, one for each text.
         """
-        from oci.generative_ai.models import EmbedTextDetails, OnDemandServingMode
+        from oci.generative_ai_inference.models import (
+            EmbedTextDetails,
+            OnDemandServingMode,
+        )
 
         details = EmbedTextDetails(
             compartment_id=self.compartment_id,
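A rough sketch of the embedding call the new import path supports, driving the inference client directly; the model name, the inputs field, and reading vectors from data.embeddings are assumptions about the OCI SDK rather than part of this diff:

    from oci.generative_ai_inference.models import EmbedTextDetails, OnDemandServingMode

    details = EmbedTextDetails(
        compartment_id="<compartment-ocid>",  # placeholder OCID
        serving_mode=OnDemandServingMode(model_id="cohere.embed-english-v3.0"),  # assumed model name
        inputs=["Example text to embed."],  # assumed field name
    )
    # `client` is a GenerativeAiInferenceClient; one vector is returned per input.
    embeddings = client.embed_text(details).data.embeddings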

ads/llm/langchain/plugins/llm_gen_ai.py

Lines changed: 75 additions & 87 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*--
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import logging
@@ -10,7 +10,7 @@
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 
 from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
-from ads.llm.langchain.plugins.contant import *
+from ads.llm.langchain.plugins.contant import Task
 
 logger = logging.getLogger(__name__)
 
@@ -32,7 +32,7 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
     """
 
     task: str = "text_generation"
-    """Indicates the task."""
+    """Task can be either text_generation or text_summarization."""
 
     model: Optional[str] = "cohere.command"
     """Model name to use."""
@@ -106,7 +106,7 @@ def _default_params(self) -> Dict[str, Any]:
 
     def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
         params = self._default_params
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
             return {**params}
 
         if self.stop is not None and stop is not None:
@@ -149,11 +149,7 @@ def _call(
         self._print_request(prompt, params)
 
         try:
-            completion = (
-                self.completion_with_retry(prompt=prompt, **params)
-                if self.task == Task.TEXT_GENERATION
-                else self.completion_with_retry(input=prompt, **params)
-            )
+            completion = self.completion_with_retry(prompt=prompt, **params)
         except Exception:
             logger.error(
                 "Error occur when invoking oci service api."
@@ -164,103 +160,95 @@ def _call(
             )
             raise
 
-        # completion = self._process_response(response, params.get("num_generations", 1))
-        # self._print_response(completion, response)
         return completion
 
-    def _process_response(self, response: Any, num_generations: int = 1) -> str:
-        if self.task == Task.SUMMARY_TEXT:
-            return response.data.summary
-
-        return (
-            response.data.generated_texts[0][0].text
-            if num_generations == 1
-            else [gen.text for gen in response.data.generated_texts[0]]
+    def _text_generation(self, request_class, serving_mode, **kwargs):
+        from oci.generative_ai_inference.models import (
+            GenerateTextDetails,
+            GenerateTextResult,
         )
 
-    def _completion_with_retry_v1(self, **kwargs: Any):
-        from oci.generative_ai.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
+        compartment_id = kwargs.pop("compartment_id")
+        inference_request = request_class(**kwargs)
+        response = self.client.generate_text(
+            GenerateTextDetails(
+                compartment_id=compartment_id,
+                serving_mode=serving_mode,
+                inference_request=inference_request,
+            ),
+            **self.endpoint_kwargs,
+        ).data
+        response: GenerateTextResult
+        return response.inference_response
+
+    def _cohere_completion(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import (
+            CohereLlmInferenceRequest,
+            CohereLlmInferenceResponse,
         )
 
-        # TODO: Add retry logic for OCI
-        # Convert the ``model`` parameter to OCI ``ServingMode``
-        # Note that "ServingMode` is not JSON serializable.
-        kwargs["prompts"] = [kwargs.pop("prompt")]
-        kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            response = self.client.generate_text(
-                GenerateTextDetails(**kwargs), **self.endpoint_kwargs
-            )
-            if kwargs.get("num_generations", 1) == 1:
-                completion = response.data.generated_texts[0][0].text
-            else:
-                completion = [gen.text for gen in response.data.generated_texts[0]]
+        response = self._text_generation(
+            CohereLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: CohereLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.generated_texts[0].text
         else:
-            response = self.client.summarize_text(
-                SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
-            )
-            completion = response.data.summary
+            completion = [result.text for result in response.generated_texts]
         self._print_response(completion, response)
         return completion
 
-    def _completion_with_retry_v2(self, **kwargs: Any):
+    def _llama_completion(self, serving_mode, **kwargs) -> str:
         from oci.generative_ai_inference.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
-            CohereLlmInferenceRequest,
             LlamaLlmInferenceRequest,
+            LlamaLlmInferenceResponse,
         )
 
-        request_class_mapping = {
-            "cohere": CohereLlmInferenceRequest,
-            "llama": LlamaLlmInferenceRequest,
-        }
+        # truncate and stop_sequence are not supported.
+        kwargs.pop("truncate", None)
+        kwargs.pop("stop_sequences", None)
+        # top_k must be >1 or -1
+        if "top_k" in kwargs and kwargs["top_k"] == 0:
+            kwargs.pop("top_k")
 
-        request_class = None
-        for prefix, oci_request_class in request_class_mapping.items():
-            if self.model.startswith(prefix):
-                request_class = oci_request_class
-        if not request_class:
-            raise ValueError(f"Model {self.model} is not supported.")
-
-        if self.model.startswith("llama"):
-            kwargs.pop("truncate", None)
-            kwargs.pop("stop_sequences", None)
-
-        serving_mode = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            compartment_id = kwargs.pop("compartment_id")
-            inference_request = request_class(**kwargs)
-            response = self.client.generate_text(
-                GenerateTextDetails(
-                    compartment_id=compartment_id,
-                    serving_mode=serving_mode,
-                    inference_request=inference_request,
-                ),
-                **self.endpoint_kwargs,
-            )
-            if kwargs.get("num_generations", 1) == 1:
-                completion = response.data.inference_response.generated_texts[0].text
-            else:
-                completion = [gen.text for gen in response.data.generated_texts]
+        # top_p must be 1 when temperature is 0
+        if kwargs.get("temperature") == 0:
+            kwargs["top_p"] = 1
 
+        response = self._text_generation(
+            LlamaLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: LlamaLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.choices[0].text
         else:
-            response = self.client.summarize_text(
-                SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
-                **self.endpoint_kwargs,
-            )
-            completion = response.data.summary
+            completion = [result.text for result in response.choices]
         self._print_response(completion, response)
         return completion
 
+    def _cohere_summarize(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import SummarizeTextDetails
+
+        kwargs["input"] = kwargs.pop("prompt")
+
+        response = self.client.summarize_text(
+            SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
+            **self.endpoint_kwargs,
+        )
+        return response.data.summary
+
     def completion_with_retry(self, **kwargs: Any) -> Any:
-        if self.client.__class__.__name__ == "GenerativeAiClient":
-            return self._completion_with_retry_v1(**kwargs)
-        return self._completion_with_retry_v2(**kwargs)
+        from oci.generative_ai_inference.models import OnDemandServingMode
+
+        serving_mode = OnDemandServingMode(model_id=self.model)
+
+        if self.task == Task.TEXT_SUMMARIZATION:
+            return self._cohere_summarize(serving_mode, **kwargs)
+        elif self.model.startswith("cohere"):
+            return self._cohere_completion(serving_mode, **kwargs)
+        elif self.model.startswith("meta.llama"):
+            return self._llama_completion(serving_mode, **kwargs)
+        raise ValueError(f"Model {self.model} is not supported.")
 
     def batch_completion(
         self,
@@ -299,9 +287,9 @@ def batch_completion(
         responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)
 
         """
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
             raise NotImplementedError(
-                f"task={Task.SUMMARY_TEXT} does not support batch_completion. "
+                f"task={Task.TEXT_SUMMARIZATION} does not support batch_completion. "
             )
 
         return self._call(
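Putting the new routing together, a hedged end-to-end sketch of using the updated plugin; the compartment OCID is a placeholder and the constructor arguments assume the fields exposed by GenerativeAiClientModel and BaseLLM:

    from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI

    llm = GenerativeAI(
        compartment_id="<compartment-ocid>",  # placeholder
        model="cohere.command",               # dispatched to _cohere_completion
        task="text_generation",
    )
    print(llm("Tell me a joke."))

    # Models starting with "meta.llama" go through _llama_completion, and
    # task="text_summarization" always uses the Cohere summarize endpoint
    # via _cohere_summarize.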
