@@ -149,8 +149,8 @@ def _call(
         self._print_request(prompt, params)

         try:
-            response = (
-                self.completion_with_retry(prompts=[prompt], **params)
+            completion = (
+                self.completion_with_retry(prompt=prompt, **params)
                 if self.task == Task.TEXT_GENERATION
                 else self.completion_with_retry(input=prompt, **params)
             )
@@ -164,8 +164,8 @@ def _call(
             )
             raise

-        completion = self._process_response(response, params.get("num_generations", 1))
-        self._print_response(completion, response)
+        # completion = self._process_response(response, params.get("num_generations", 1))
+        # self._print_response(completion, response)
         return completion

     def _process_response(self, response: Any, num_generations: int = 1) -> str:
@@ -178,7 +178,7 @@ def _process_response(self, response: Any, num_generations: int = 1) -> str:
             else [gen.text for gen in response.data.generated_texts[0]]
         )

-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def _completion_with_retry_v1(self, **kwargs: Any):
         from oci.generative_ai.models import (
             GenerateTextDetails,
             OnDemandServingMode,
@@ -188,15 +188,79 @@ def completion_with_retry(self, **kwargs: Any) -> Any:
         # TODO: Add retry logic for OCI
         # Convert the ``model`` parameter to OCI ``ServingMode``
         # Note that ``ServingMode`` is not JSON serializable.
+        kwargs["prompts"] = [kwargs.pop("prompt")]
         kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
         if self.task == Task.TEXT_GENERATION:
-            return self.client.generate_text(
+            response = self.client.generate_text(
                 GenerateTextDetails(**kwargs), **self.endpoint_kwargs
             )
+            if kwargs.get("num_generations", 1) == 1:
+                completion = response.data.generated_texts[0][0].text
+            else:
+                completion = [gen.text for gen in response.data.generated_texts[0]]
         else:
-            return self.client.summarize_text(
+            response = self.client.summarize_text(
                 SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
             )
+            completion = response.data.summary
+        self._print_response(completion, response)
+        return completion
+
+    def _completion_with_retry_v2(self, **kwargs: Any):
+        from oci.generative_ai_inference.models import (
+            GenerateTextDetails,
+            OnDemandServingMode,
+            SummarizeTextDetails,
+            CohereLlmInferenceRequest,
+            LlamaLlmInferenceRequest,
+        )
+
+        request_class_mapping = {
+            "cohere": CohereLlmInferenceRequest,
+            "llama": LlamaLlmInferenceRequest,
+        }
+
+        request_class = None
+        for prefix, oci_request_class in request_class_mapping.items():
+            if self.model.startswith(prefix):
+                request_class = oci_request_class
+        if not request_class:
+            raise ValueError(f"Model {self.model} is not supported.")
+
+        if self.model.startswith("llama"):
+            kwargs.pop("truncate", None)
+            kwargs.pop("stop_sequences", None)
+
+        serving_mode = OnDemandServingMode(model_id=self.model)
+        if self.task == Task.TEXT_GENERATION:
+            compartment_id = kwargs.pop("compartment_id")
+            inference_request = request_class(**kwargs)
+            response = self.client.generate_text(
+                GenerateTextDetails(
+                    compartment_id=compartment_id,
+                    serving_mode=serving_mode,
+                    inference_request=inference_request,
+                ),
+                **self.endpoint_kwargs,
+            )
+            if kwargs.get("num_generations", 1) == 1:
+                completion = response.data.inference_response.generated_texts[0].text
+            else:
+                completion = [gen.text for gen in response.data.inference_response.generated_texts]
+        else:
+            response = self.client.summarize_text(
+                SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
+                **self.endpoint_kwargs,
+            )
+            completion = response.data.summary
+        self._print_response(completion, response)
+        return completion
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        if self.client.__class__.__name__ == "GenerativeAiClient":
+            return self._completion_with_retry_v1(**kwargs)
+        return self._completion_with_retry_v2(**kwargs)

     def batch_completion(
         self,
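
Note on the new dispatch: `completion_with_retry` now routes on the client's class name rather than on the task alone. Below is a minimal, self-contained sketch of that routing; the stub classes are hypothetical stand-ins for `oci.generative_ai.GenerativeAiClient` and `oci.generative_ai_inference.GenerativeAiInferenceClient`, and only the class-name check mirrors the diff.

# Hypothetical stand-ins for the two OCI SDK clients; real code would
# import them from oci.generative_ai and oci.generative_ai_inference.
class GenerativeAiClient:
    pass

class GenerativeAiInferenceClient:
    pass

def pick_backend(client) -> str:
    # Same check as completion_with_retry above: the legacy client name
    # selects the v1 path, anything else falls through to v2.
    if client.__class__.__name__ == "GenerativeAiClient":
        return "v1"
    return "v2"

assert pick_backend(GenerativeAiClient()) == "v1"
assert pick_backend(GenerativeAiInferenceClient()) == "v2"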
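
The two code paths also parse different response shapes, which is why the parsing moved out of `_process_response` and into each helper. A small illustration of the attribute paths handled above, using `types.SimpleNamespace` stubs in place of real SDK responses (the stub data is fabricated; the field paths are taken from the diff):

from types import SimpleNamespace as NS

# v1 (oci.generative_ai): one inner list of generations per prompt.
v1_response = NS(data=NS(generated_texts=[[NS(text="hello")]]))
assert v1_response.data.generated_texts[0][0].text == "hello"

# v2 (oci.generative_ai_inference): a flat list nested under
# data.inference_response.
v2_response = NS(data=NS(inference_response=NS(generated_texts=[NS(text="hello")])))
assert v2_response.data.inference_response.generated_texts[0].text == "hello"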