@@ -236,9 +236,28 @@ async def generate(
236
236
timeout : float = 180.0 , # Increased timeout for more complete responses (3 minutes)
237
237
repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
238
238
top_k : int = 80 , # Added top_k parameter for better quality
239
- do_sample : bool = True # Added do_sample parameter
239
+ do_sample : bool = True , # Added do_sample parameter
240
+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
240
241
) -> str :
241
- """Generate text using the model with improved error handling"""
242
+ """
243
+ Generate text using the model with improved error handling.
244
+
245
+ Args:
246
+ prompt: The prompt to generate text from
247
+ model_id: Optional model ID to use
248
+ stream: Whether to stream the response
249
+ max_length: Maximum length of the generated text
250
+ temperature: Temperature for sampling
251
+ top_p: Top-p for nucleus sampling
252
+ timeout: Request timeout in seconds
253
+ repetition_penalty: Penalty for repetition (higher values = less repetition)
254
+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
255
+ do_sample: Whether to use sampling instead of greedy decoding
256
+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
257
+
258
+ Returns:
259
+ The generated text as a string.
260
+ """
242
261
# Update activity timestamp
243
262
self ._update_activity ()
244
263
@@ -254,6 +273,10 @@ async def generate(
254
273
"do_sample" : do_sample
255
274
}
256
275
276
+ # Add max_time parameter if provided
277
+ if max_time is not None :
278
+ payload ["max_time" ] = max_time
279
+
257
280
if stream :
258
281
return self .stream_generate (
259
282
prompt = prompt ,
@@ -311,7 +334,8 @@ async def stream_generate(
311
334
retry_count : int = 3 , # Increased retry count for better reliability
312
335
repetition_penalty : float = 1.15 , # Increased repetition penalty for better quality
313
336
top_k : int = 80 , # Added top_k parameter for better quality
314
- do_sample : bool = True # Added do_sample parameter
337
+ do_sample : bool = True , # Added do_sample parameter
338
+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
315
339
) -> AsyncGenerator [str , None ]:
316
340
"""
317
341
Stream text generation with token-level streaming and robust error handling.
@@ -326,6 +350,8 @@ async def stream_generate(
326
350
retry_count: Number of retries for network errors
327
351
repetition_penalty: Penalty for repetition (higher values = less repetition)
328
352
top_k: Top-k for sampling (higher values = more diverse vocabulary)
353
+ do_sample: Whether to use sampling instead of greedy decoding
354
+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
329
355
330
356
Returns:
331
357
A generator that yields chunks of text as they are generated.
@@ -349,6 +375,10 @@ async def stream_generate(
349
375
"do_sample" : do_sample
350
376
}
351
377
378
+ # Add max_time parameter if provided
379
+ if max_time is not None :
380
+ payload ["max_time" ] = max_time
381
+
352
382
# Create a timeout for this specific request
353
383
request_timeout = aiohttp .ClientTimeout (total = timeout )
354
384
@@ -473,9 +503,27 @@ async def chat(
473
503
top_p : float = 0.9 ,
474
504
timeout : float = 180.0 , # Increased timeout for more complete responses (3 minutes)
475
505
repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
476
- top_k : int = 80 # Added top_k parameter for better quality
506
+ top_k : int = 80 , # Added top_k parameter for better quality
507
+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
477
508
) -> Dict [str , Any ]:
478
- """Chat completion endpoint with improved error handling"""
509
+ """
510
+ Chat completion endpoint with improved error handling.
511
+
512
+ Args:
513
+ messages: List of message dictionaries with 'role' and 'content' keys
514
+ model_id: Optional model ID to use
515
+ stream: Whether to stream the response
516
+ max_length: Maximum length of the generated text
517
+ temperature: Temperature for sampling
518
+ top_p: Top-p for nucleus sampling
519
+ timeout: Request timeout in seconds
520
+ repetition_penalty: Penalty for repetition (higher values = less repetition)
521
+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
522
+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
523
+
524
+ Returns:
525
+ The chat completion response as a dictionary.
526
+ """
479
527
# Update activity timestamp
480
528
self ._update_activity ()
481
529
@@ -490,6 +538,10 @@ async def chat(
490
538
"top_k" : top_k
491
539
}
492
540
541
+ # Add max_time parameter if provided
542
+ if max_time is not None :
543
+ payload ["max_time" ] = max_time
544
+
493
545
if stream :
494
546
return self .stream_chat (
495
547
messages = messages ,
@@ -538,9 +590,27 @@ async def stream_chat(
538
590
timeout : float = 300.0 , # Increased timeout for more complete responses (5 minutes)
539
591
retry_count : int = 3 , # Increased retry count for better reliability
540
592
repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
541
- top_k : int = 80 # Added top_k parameter for better quality
593
+ top_k : int = 80 , # Added top_k parameter for better quality
594
+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
542
595
) -> AsyncGenerator [Dict [str , Any ], None ]:
543
- """Stream chat completion with robust error handling"""
596
+ """
597
+ Stream chat completion with robust error handling.
598
+
599
+ Args:
600
+ messages: List of message dictionaries with 'role' and 'content' keys
601
+ model_id: Optional model ID to use
602
+ max_length: Maximum length of the generated text
603
+ temperature: Temperature for sampling
604
+ top_p: Top-p for nucleus sampling
605
+ timeout: Request timeout in seconds
606
+ retry_count: Number of retries for network errors
607
+ repetition_penalty: Penalty for repetition (higher values = less repetition)
608
+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
609
+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
610
+
611
+ Returns:
612
+ A generator that yields chunks of the chat completion response.
613
+ """
544
614
# Update activity timestamp
545
615
self ._update_activity ()
546
616
@@ -555,6 +625,10 @@ async def stream_chat(
555
625
"top_k" : top_k
556
626
}
557
627
628
+ # Add max_time parameter if provided
629
+ if max_time is not None :
630
+ payload ["max_time" ] = max_time
631
+
558
632
# Create a timeout for this specific request
559
633
request_timeout = aiohttp .ClientTimeout (total = timeout )
560
634
@@ -661,9 +735,26 @@ async def batch_generate(
661
735
top_p : float = 0.9 ,
662
736
timeout : float = 300.0 , # Increased timeout for more complete responses (5 minutes)
663
737
repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
664
- top_k : int = 80 # Added top_k parameter for better quality
738
+ top_k : int = 80 , # Added top_k parameter for better quality
739
+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
665
740
) -> Dict [str , List [str ]]:
666
- """Generate text for multiple prompts in parallel with improved error handling"""
741
+ """
742
+ Generate text for multiple prompts in parallel with improved error handling.
743
+
744
+ Args:
745
+ prompts: List of prompts to generate text from
746
+ model_id: Optional model ID to use
747
+ max_length: Maximum length of the generated text
748
+ temperature: Temperature for sampling
749
+ top_p: Top-p for nucleus sampling
750
+ timeout: Request timeout in seconds
751
+ repetition_penalty: Penalty for repetition (higher values = less repetition)
752
+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
753
+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
754
+
755
+ Returns:
756
+ Dictionary with the generated responses.
757
+ """
667
758
# Update activity timestamp
668
759
self ._update_activity ()
669
760
@@ -677,6 +768,10 @@ async def batch_generate(
677
768
"top_k" : top_k
678
769
}
679
770
771
+ # Add max_time parameter if provided
772
+ if max_time is not None :
773
+ payload ["max_time" ] = max_time
774
+
680
775
# Create a timeout for this specific request
681
776
request_timeout = aiohttp .ClientTimeout (total = timeout )
682
777
0 commit comments