#!/usr/bin/env python
# -*- coding: utf-8 -*--

-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging
from langchain.callbacks.manager import CallbackManagerForLLMRun

from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
-from ads.llm.langchain.plugins.contant import *
+from ads.llm.langchain.plugins.contant import Task

logger = logging.getLogger(__name__)

@@ -32,7 +32,7 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
    """

    task: str = "text_generation"
-    """Indicates the task."""
+    """Task can be either text_generation or text_summarization."""

    model: Optional[str] = "cohere.command"
    """Model name to use."""
@@ -106,7 +106,7 @@ def _default_params(self) -> Dict[str, Any]:

    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
        params = self._default_params
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
            return {**params}

        if self.stop is not None and stop is not None:
@@ -149,11 +149,7 @@ def _call(
        self._print_request(prompt, params)

        try:
-            completion = (
-                self.completion_with_retry(prompt=prompt, **params)
-                if self.task == Task.TEXT_GENERATION
-                else self.completion_with_retry(input=prompt, **params)
-            )
+            completion = self.completion_with_retry(prompt=prompt, **params)
        except Exception:
            logger.error(
                "Error occur when invoking oci service api."
@@ -164,103 +160,95 @@ def _call(
            )
            raise

-        # completion = self._process_response(response, params.get("num_generations", 1))
-        # self._print_response(completion, response)
        return completion

-    def _process_response(self, response: Any, num_generations: int = 1) -> str:
-        if self.task == Task.SUMMARY_TEXT:
-            return response.data.summary
-
-        return (
-            response.data.generated_texts[0][0].text
-            if num_generations == 1
-            else [gen.text for gen in response.data.generated_texts[0]]
+    def _text_generation(self, request_class, serving_mode, **kwargs):
+        from oci.generative_ai_inference.models import (
+            GenerateTextDetails,
+            GenerateTextResult,
        )

-    def _completion_with_retry_v1(self, **kwargs: Any):
-        from oci.generative_ai.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
+        compartment_id = kwargs.pop("compartment_id")
+        inference_request = request_class(**kwargs)
+        response = self.client.generate_text(
+            GenerateTextDetails(
+                compartment_id=compartment_id,
+                serving_mode=serving_mode,
+                inference_request=inference_request,
+            ),
+            **self.endpoint_kwargs,
+        ).data
+        response: GenerateTextResult
+        return response.inference_response
+
+    def _cohere_completion(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import (
+            CohereLlmInferenceRequest,
+            CohereLlmInferenceResponse,
        )

-        # TODO: Add retry logic for OCI
-        # Convert the ``model`` parameter to OCI ``ServingMode``
-        # Note that "ServingMode` is not JSON serializable.
-        kwargs["prompts"] = [kwargs.pop("prompt")]
-        kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            response = self.client.generate_text(
-                GenerateTextDetails(**kwargs), **self.endpoint_kwargs
-            )
-            if kwargs.get("num_generations", 1) == 1:
-                completion = response.data.generated_texts[0][0].text
-            else:
-                completion = [gen.text for gen in response.data.generated_texts[0]]
+        response = self._text_generation(
+            CohereLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: CohereLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.generated_texts[0].text
        else:
-            response = self.client.summarize_text(
-                SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
-            )
-            completion = response.data.summary
+            completion = [result.text for result in response.generated_texts]
        self._print_response(completion, response)
        return completion

-    def _completion_with_retry_v2(self, **kwargs: Any):
+    def _llama_completion(self, serving_mode, **kwargs) -> str:
        from oci.generative_ai_inference.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
-            CohereLlmInferenceRequest,
            LlamaLlmInferenceRequest,
+            LlamaLlmInferenceResponse,
        )

-        request_class_mapping = {
-            "cohere": CohereLlmInferenceRequest,
-            "llama": LlamaLlmInferenceRequest,
-        }
+        # truncate and stop_sequence are not supported.
+        kwargs.pop("truncate", None)
+        kwargs.pop("stop_sequences", None)
+        # top_k must be >1 or -1
+        if "top_k" in kwargs and kwargs["top_k"] == 0:
+            kwargs.pop("top_k")

-        request_class = None
-        for prefix, oci_request_class in request_class_mapping.items():
-            if self.model.startswith(prefix):
-                request_class = oci_request_class
-        if not request_class:
-            raise ValueError(f"Model {self.model} is not supported.")
-
-        if self.model.startswith("llama"):
-            kwargs.pop("truncate", None)
-            kwargs.pop("stop_sequences", None)
-
-        serving_mode = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            compartment_id = kwargs.pop("compartment_id")
-            inference_request = request_class(**kwargs)
-            response = self.client.generate_text(
-                GenerateTextDetails(
-                    compartment_id=compartment_id,
-                    serving_mode=serving_mode,
-                    inference_request=inference_request,
-                ),
-                **self.endpoint_kwargs,
-            )
-            if kwargs.get("num_generations", 1) == 1:
-                completion = response.data.inference_response.generated_texts[0].text
-            else:
-                completion = [gen.text for gen in response.data.generated_texts]
+        # top_p must be 1 when temperature is 0
+        if kwargs.get("temperature") == 0:
+            kwargs["top_p"] = 1

+        response = self._text_generation(
+            LlamaLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: LlamaLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.choices[0].text
        else:
-            response = self.client.summarize_text(
-                SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
-                **self.endpoint_kwargs,
-            )
-            completion = response.data.summary
+            completion = [result.text for result in response.choices]
        self._print_response(completion, response)
        return completion

+    def _cohere_summarize(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import SummarizeTextDetails
+
+        kwargs["input"] = kwargs.pop("prompt")
+
+        response = self.client.summarize_text(
+            SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
+            **self.endpoint_kwargs,
+        )
+        return response.data.summary
+
    def completion_with_retry(self, **kwargs: Any) -> Any:
-        if self.client.__class__.__name__ == "GenerativeAiClient":
-            return self._completion_with_retry_v1(**kwargs)
-        return self._completion_with_retry_v2(**kwargs)
+        from oci.generative_ai_inference.models import OnDemandServingMode
+
+        serving_mode = OnDemandServingMode(model_id=self.model)
+
+        if self.task == Task.TEXT_SUMMARIZATION:
+            return self._cohere_summarize(serving_mode, **kwargs)
+        elif self.model.startswith("cohere"):
+            return self._cohere_completion(serving_mode, **kwargs)
+        elif self.model.startswith("meta.llama"):
+            return self._llama_completion(serving_mode, **kwargs)
+        raise ValueError(f"Model {self.model} is not supported.")

    def batch_completion(
        self,
@@ -299,9 +287,9 @@ def batch_completion(
        responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)

        """
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
            raise NotImplementedError(
-                f"task={Task.SUMMARY_TEXT} does not support batch_completion. "
+                f"task={Task.TEXT_SUMMARIZATION} does not support batch_completion. "
            )

        return self._call(
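A minimal usage sketch of the refactored dispatch follows; it is not part of the diff. It assumes GenerativeAI is importable from ads.llm, that OCI credentials and a Generative AI endpoint are configured, and it uses placeholder values for the compartment OCID and model names.

# Hypothetical usage sketch: placeholder values, not taken from the diff above.
from ads.llm import GenerativeAI

# A model starting with "cohere" and the default task routes to _cohere_completion;
# a model starting with "meta.llama" would route to _llama_completion instead.
llm = GenerativeAI(
    compartment_id="<compartment_ocid>",
    model="cohere.command",
    task="text_generation",
)

# task="text_summarization" routes to _cohere_summarize.
summarizer = GenerativeAI(
    compartment_id="<compartment_ocid>",
    model="cohere.command",
    task="text_summarization",
)

# Both objects are standard LangChain LLMs, so invocation goes through the usual
# LangChain interface, e.g. llm("Tell me a joke.") on the LangChain versions
# supported by this plugin.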