Commit fe3aac6

Support OCI generative AI inference client.
1 parent 37e116e

2 files changed: 106 additions, 17 deletions

ads/llm/langchain/plugins/base.py

Lines changed: 35 additions & 10 deletions
@@ -77,27 +77,52 @@ class GenerativeAiClientModel(BaseModel):
     client_kwargs: Dict[str, Any] = {}
     """Holds any client parameters for creating GenerativeAiClient"""
 
+    @staticmethod
+    def _import_client_v1():
+        from oci.generative_ai import GenerativeAiClient
+
+        return GenerativeAiClient
+
+    @staticmethod
+    def _import_client_v2():
+        from oci.generative_ai_inference import GenerativeAiInferenceClient
+
+        return GenerativeAiInferenceClient
+
+    @staticmethod
+    def _import_client():
+        import_methods = [
+            GenerativeAiClientModel._import_client_v1,
+            GenerativeAiClientModel._import_client_v2,
+        ]
+        client_class = None
+        for import_client_method in import_methods:
+            try:
+                # A later success overwrites an earlier one, so the v2
+                # inference client is preferred when both are importable.
+                client_class = import_client_method()
+            except ImportError:
+                pass
+        if not client_class:
+            raise ImportError(
+                "Could not import GenerativeAiClient or GenerativeAiInferenceClient from oci. "
+                "The installed OCI SDK does not support the generative AI service."
+            )
+        return client_class
+
     @root_validator()
     def validate_environment(  # pylint: disable=no-self-argument
         cls, values: Dict
     ) -> Dict:
         """Validate that python package exists in environment."""
-        try:
-            # Import the GenerativeAIClient here so that there will be no error when user import ads.llm
-            # and the install OCI SDK does not support generative AI service yet.
-            from oci.generative_ai import GenerativeAiClient
-        except ImportError as ex:
-            raise ImportError(
-                "Could not import GenerativeAIClient from oci. "
-                "The OCI SDK installed does not support generative AI service."
-            ) from ex
         # Initialize client only if user does not pass in client.
         # Users may choose to initialize the OCI client by themselves and pass it into this model.
         if not values.get("client"):
             auth = values.get("auth", {})
             client_kwargs = auth.get("client_kwargs") or {}
             client_kwargs.update(values["client_kwargs"])
-            values["client"] = GenerativeAiClient(**auth, **client_kwargs)
+            # Import the client class here so that importing ads.llm raises no error
+            # when the installed OCI SDK does not support the generative AI service yet.
+            client_class = cls._import_client()
+            values["client"] = client_class(**auth, **client_kwargs)
         # Set default compartment ID
         if not values.get("compartment_id"):
             if COMPARTMENT_OCID:
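
The deferred, two-step import above lets `import ads.llm` succeed even when the installed OCI SDK predates the generative AI service, and because the loop lets a later successful import overwrite an earlier one, the v2 `GenerativeAiInferenceClient` wins when both modules are available. A minimal standalone sketch of the same fallback pattern (the helper name is illustrative; only the two `oci` module paths come from the diff):

```python
def pick_generative_ai_client():
    """Return the generative AI client class supported by the installed OCI SDK."""
    client_class = None
    for module_path, class_name in [
        ("oci.generative_ai", "GenerativeAiClient"),                     # v1
        ("oci.generative_ai_inference", "GenerativeAiInferenceClient"),  # v2
    ]:
        try:
            module = __import__(module_path, fromlist=[class_name])
            # A later hit overwrites an earlier one, so v2 is preferred.
            client_class = getattr(module, class_name)
        except ImportError:
            pass
    if client_class is None:
        raise ImportError(
            "The installed OCI SDK does not support the generative AI service."
        )
    return client_class
```

Catching `ImportError` per candidate, rather than comparing SDK version strings, keeps the check robust to how the service modules are packaged.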

ads/llm/langchain/plugins/llm_gen_ai.py

Lines changed: 71 additions & 7 deletions
@@ -149,8 +149,8 @@ def _call(
         self._print_request(prompt, params)
 
         try:
-            response = (
-                self.completion_with_retry(prompts=[prompt], **params)
+            completion = (
+                self.completion_with_retry(prompt=prompt, **params)
                 if self.task == Task.TEXT_GENERATION
                 else self.completion_with_retry(input=prompt, **params)
             )
@@ -164,8 +164,8 @@ def _call(
             )
             raise
 
-        completion = self._process_response(response, params.get("num_generations", 1))
-        self._print_response(completion, response)
+        # completion = self._process_response(response, params.get("num_generations", 1))
+        # self._print_response(completion, response)
         return completion
 
     def _process_response(self, response: Any, num_generations: int = 1) -> str:
@@ -178,7 +178,7 @@ def _process_response(self, response: Any, num_generations: int = 1) -> str:
             else [gen.text for gen in response.data.generated_texts[0]]
         )
 
-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def _completion_with_retry_v1(self, **kwargs: Any):
         from oci.generative_ai.models import (
             GenerateTextDetails,
             OnDemandServingMode,
@@ -188,15 +188,79 @@ def completion_with_retry(self, **kwargs: Any) -> Any:
         # TODO: Add retry logic for OCI
         # Convert the ``model`` parameter to OCI ``ServingMode``
         # Note that ``ServingMode`` is not JSON serializable.
+        # The v1 text generation API expects a list of prompts.
+        kwargs["prompts"] = [kwargs.pop("prompt")]
         kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
         if self.task == Task.TEXT_GENERATION:
-            return self.client.generate_text(
+            response = self.client.generate_text(
                 GenerateTextDetails(**kwargs), **self.endpoint_kwargs
             )
+            if kwargs.get("num_generations", 1) == 1:
+                completion = response.data.generated_texts[0][0].text
+            else:
+                completion = [gen.text for gen in response.data.generated_texts[0]]
         else:
-            return self.client.summarize_text(
+            response = self.client.summarize_text(
                 SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
             )
+            completion = response.data.summary
+        self._print_response(completion, response)
+        return completion
+
+    def _completion_with_retry_v2(self, **kwargs: Any):
+        from oci.generative_ai_inference.models import (
+            GenerateTextDetails,
+            OnDemandServingMode,
+            SummarizeTextDetails,
+            CohereLlmInferenceRequest,
+            LlamaLlmInferenceRequest,
+        )
+
+        # Route the model to its inference request class by model name prefix.
+        request_class_mapping = {
+            "cohere": CohereLlmInferenceRequest,
+            "llama": LlamaLlmInferenceRequest,
+        }
+
+        request_class = None
+        for prefix, oci_request_class in request_class_mapping.items():
+            if self.model.startswith(prefix):
+                request_class = oci_request_class
+        if not request_class:
+            raise ValueError(f"Model {self.model} is not supported.")
+
+        # Llama models do not accept these Cohere-specific parameters.
+        if self.model.startswith("llama"):
+            kwargs.pop("truncate", None)
+            kwargs.pop("stop_sequences", None)
+
+        serving_mode = OnDemandServingMode(model_id=self.model)
+        if self.task == Task.TEXT_GENERATION:
+            compartment_id = kwargs.pop("compartment_id")
+            inference_request = request_class(**kwargs)
+            response = self.client.generate_text(
+                GenerateTextDetails(
+                    compartment_id=compartment_id,
+                    serving_mode=serving_mode,
+                    inference_request=inference_request,
+                ),
+                **self.endpoint_kwargs,
+            )
+            if kwargs.get("num_generations", 1) == 1:
+                completion = response.data.inference_response.generated_texts[0].text
+            else:
+                completion = [
+                    gen.text
+                    for gen in response.data.inference_response.generated_texts
+                ]
+        else:
+            response = self.client.summarize_text(
+                SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
+                **self.endpoint_kwargs,
+            )
+            completion = response.data.summary
+        self._print_response(completion, response)
+        return completion
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        # Dispatch on the client class picked in base.py: v1 GenerativeAiClient
+        # or v2 GenerativeAiInferenceClient.
+        if self.client.__class__.__name__ == "GenerativeAiClient":
+            return self._completion_with_retry_v1(**kwargs)
+        return self._completion_with_retry_v2(**kwargs)
 
     def batch_completion(
         self,
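
For context, this is roughly the request that `_completion_with_retry_v2` assembles when talking to the v2 client directly. A hedged sketch: the OCIDs and model name are placeholders, `max_tokens` is an assumed request field, and `oci.config.from_file()` is the standard SDK config loader; only the client, serving-mode, request, and response attribute names come from the diff above:

```python
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
    CohereLlmInferenceRequest,
    GenerateTextDetails,
    OnDemandServingMode,
)

config = oci.config.from_file()  # default ~/.oci/config profile
# A service_endpoint kwarg may be needed if the region is not resolved automatically.
client = GenerativeAiInferenceClient(config)

response = client.generate_text(
    GenerateTextDetails(
        compartment_id="ocid1.compartment.oc1..example",              # placeholder OCID
        serving_mode=OnDemandServingMode(model_id="cohere.command"),  # placeholder model
        inference_request=CohereLlmInferenceRequest(
            prompt="Tell me a joke.",
            max_tokens=50,  # assumed request field, not shown in the diff
        ),
    )
)
print(response.data.inference_response.generated_texts[0].text)
```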

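Finally, a hypothetical end-to-end use of the plugin this commit updates, assuming `ads.llm` exports the `GenerativeAI` LangChain LLM defined in `llm_gen_ai.py`; the compartment OCID and model name are placeholders:

```python
# Hypothetical usage of the LangChain plugin backed by the code above.
from ads.llm import GenerativeAI  # assumed export path

llm = GenerativeAI(
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
    model="cohere.command",  # "cohere"/"llama" prefixes are routed by the mapping above
)
print(llm.predict("Tell me a joke."))
```

With the dispatch in `completion_with_retry`, the same call works whether the installed SDK provides the v1 `GenerativeAiClient` or the v2 `GenerativeAiInferenceClient`.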