@@ -231,6 +231,14 @@ async def load_model(self, model_id: str) -> bool:
         gc.collect()
         log_model_unloaded(prev_model)
 
+        # Set CUDA memory allocation configuration to avoid fragmentation
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+        # Configure torch.compile to suppress errors and fall back to eager mode
+        if hasattr(torch, "_dynamo"):
+            torch._dynamo.config.suppress_errors = True
+            logger.info("Configured torch._dynamo to suppress errors and fall back to eager mode")
+
         hf_token = os.getenv("HF_TOKEN")
 
         # Check quantization settings from environment variables
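Note that PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA caching allocator is first initialized, so the setting above only takes effect if no CUDA memory has been allocated yet. A minimal sketch of setting it defensively at process start (the entry point shown here is an assumption for illustration, not part of this module):

    # Hypothetical entry point: put the allocator config in place before the first CUDA allocation.
    import os
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    import torch  # imported after the environment variable is set

    if torch.cuda.is_available():
        torch.zeros(1, device="cuda")  # the first allocation picks up expandable_segments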
@@ -270,29 +278,30 @@ async def load_model(self, model_id: str) -> bool:
         self.model = self._apply_optimizations(self.model)
 
         # Try to compile the model for faster inference if PyTorch version supports it
+        # But with more robust error handling
         try:
-            if torch.__version__ >= "2.0.0" and torch.cuda.is_available():
-                logger.info("Attempting to compile model with torch.compile() for faster inference...")
-                # Use a separate thread to avoid blocking
-                import threading
-
-                def compile_model():
+            if hasattr(torch, 'compile') and torch.cuda.is_available():
+                # Only attempt compilation if we have enough GPU memory
+                free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
+                if free_memory > 2 * 1024 * 1024 * 1024:  # Only compile if we have >2GB free
+                    logger.info("Attempting to compile model with torch.compile() for faster inference...")
                     try:
-                        # Only compile the forward method for generation
-                        self.model.forward = torch.compile(
-                            self.model.forward,
-                            mode="reduce-overhead",  # Use reduce-overhead mode for faster compilation
-                            fullgraph=False  # Partial graph compilation for better compatibility
+                        # Use a safer compilation mode
+                        self.model = torch.compile(
+                            self.model,
+                            mode="reduce-overhead",
+                            fullgraph=False,
+                            dynamic=True  # Better for variable input sizes
                         )
                         self.compiled_model = True
                         logger.info(f"{Fore.GREEN}Model successfully compiled with torch.compile(){Style.RESET_ALL}")
-                    except Exception as e:
-                        logger.warning(f"Could not compile model: {str(e)}. Continuing with uncompiled model.")
-
-                # Start compilation in background
-                threading.Thread(target=compile_model).start()
+                    except Exception as compile_error:
+                        logger.warning(f"Model compilation failed with specific error: {str(compile_error)}")
+                        logger.info("Continuing with uncompiled model")
+                else:
+                    logger.info("Skipping model compilation due to limited GPU memory")
         except Exception as e:
-            logger.warning(f"Torch compile not available: {str(e)}. Continuing with standard model.")
+            logger.warning(f"Could not compile model: {str(e)}. Continuing with uncompiled model.")
 
         # Capture model parameters after loading
         model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
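For reference, the guarded-compilation pattern above can be read as a standalone helper. This is an illustrative sketch of how the pieces fit together (the function name and the tuple return are assumptions, not code from this repository):

    import torch

    def maybe_compile(model, min_free_bytes=2 * 1024 ** 3):
        """Compile with torch.compile only when it exists, CUDA is up, and enough memory is free."""
        if not (hasattr(torch, "compile") and torch.cuda.is_available()):
            return model, False
        free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
        if free <= min_free_bytes:
            return model, False  # skip compilation under memory pressure
        try:
            return torch.compile(model, mode="reduce-overhead", fullgraph=False, dynamic=True), True
        except Exception:
            return model, False  # fall back to the eager model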
@@ -454,41 +463,62 @@ async def generate(
         if stream:
             return self.async_stream_generate(inputs, gen_params)
 
+        # Check if we need to clear CUDA cache before generation
+        if torch.cuda.is_available():
+            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+            if current_mem > 0.8 * total_mem:  # If using >80% of GPU memory
+                # Clear cache to avoid OOM
+                torch.cuda.empty_cache()
+                logger.info("Cleared CUDA cache before generation to avoid out of memory error")
+
         with torch.no_grad():
-            generate_params = {
-                **inputs,
-                "max_new_tokens": gen_params["max_length"],
-                "temperature": gen_params["temperature"],
-                "top_p": gen_params["top_p"],
-                "do_sample": True,
-                "pad_token_id": self.tokenizer.eos_token_id
-            }
+            try:
+                generate_params = {
+                    **inputs,
+                    "max_new_tokens": gen_params["max_length"],
+                    "temperature": gen_params["temperature"],
+                    "top_p": gen_params["top_p"],
+                    "do_sample": True,
+                    "pad_token_id": self.tokenizer.eos_token_id,
+                    # Fix the early stopping warning by setting num_beams explicitly
+                    "num_beams": 1
+                }
 
-            # Add optional parameters if present in gen_params
-            if "top_k" in gen_params:
-                generate_params["top_k"] = gen_params["top_k"]
-            if "repetition_penalty" in gen_params:
-                generate_params["repetition_penalty"] = gen_params["repetition_penalty"]
-
-            # Add early stopping for faster generation
-            generate_params["early_stopping"] = True
-
-            # Add batch size for faster generation (process multiple tokens at once)
-            generate_params["num_return_sequences"] = 1
-
-            # Set a reasonable max time for generation to prevent hanging
-            if "max_time" not in generate_params and not stream:
-                generate_params["max_time"] = 30.0  # 30 seconds max for generation
-
-            # Use efficient attention implementation if available
-            if hasattr(self.model.config, "attn_implementation"):
-                generate_params["attn_implementation"] = "flash_attention_2"
+                # Add optional parameters if present in gen_params
+                if "top_k" in gen_params:
+                    generate_params["top_k"] = gen_params["top_k"]
+                if "repetition_penalty" in gen_params:
+                    generate_params["repetition_penalty"] = gen_params["repetition_penalty"]
+
+                # Set a reasonable max time for generation to prevent hanging
+                if "max_time" not in generate_params and not stream:
+                    generate_params["max_time"] = 30.0  # 30 seconds max for generation
+
+                # Use efficient attention implementation if available
+                if hasattr(self.model.config, "attn_implementation"):
+                    generate_params["attn_implementation"] = "flash_attention_2"
 
-            # Generate text
-            start_time = time.time()
-            outputs = self.model.generate(**generate_params)
-            generation_time = time.time() - start_time
-            logger.info(f"Generation completed in {generation_time:.2f} seconds")
+                # Generate text
+                start_time = time.time()
+                outputs = self.model.generate(**generate_params)
+                generation_time = time.time() - start_time
+                logger.info(f"Generation completed in {generation_time:.2f} seconds")
+
+            except RuntimeError as e:
+                if "CUDA out of memory" in str(e):
+                    # If we run out of memory, clear cache and try again with smaller parameters
+                    torch.cuda.empty_cache()
+                    logger.warning("CUDA out of memory during generation. Cleared cache and reducing parameters.")
+
+                    # Reduce parameters for memory efficiency
+                    generate_params["max_new_tokens"] = min(generate_params.get("max_new_tokens", 512), 256)
+
+                    # Try again with reduced parameters
+                    outputs = self.model.generate(**generate_params)
+                else:
+                    # For other errors, re-raise
+                    raise
 
         response = self.tokenizer.decode(
             outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
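The retry-on-OOM logic above can also be sketched as a small helper; the function name and the 256-token fallback budget below are illustrative assumptions taken from the diff, not an existing API in this repository:

    import torch

    def generate_with_oom_fallback(model, generate_params, fallback_max_new_tokens=256):
        """Run model.generate; on CUDA OOM, clear the cache and retry once with a smaller token budget."""
        try:
            return model.generate(**generate_params)
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise
            torch.cuda.empty_cache()
            generate_params["max_new_tokens"] = min(
                generate_params.get("max_new_tokens", 512), fallback_max_new_tokens
            )
            return model.generate(**generate_params)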
@@ -551,9 +581,10 @@ def _stream_generate(
             input_ids = inputs["input_ids"]
             attention_mask = inputs["attention_mask"]
 
-            # Generate multiple tokens at once for efficiency
-            tokens_to_generate_per_step = 8  # Generate 8 tokens at a time for efficiency
-
+            # Generate fewer tokens at once for more responsive streaming
+            # Using smaller chunks makes it appear more interactive
+            tokens_to_generate_per_step = 3  # Reduced from 8 to 3 for more responsive streaming
+
             with torch.no_grad():
                 for step in range(0, max_length, tokens_to_generate_per_step):
                     # Calculate how many tokens to generate in this step
@@ -570,38 +601,78 @@ def _stream_generate(
                         "do_sample": True,
                         "pad_token_id": self.tokenizer.eos_token_id,
                         "repetition_penalty": repetition_penalty,
-                        "early_stopping": True
+                        # Remove early_stopping to fix the warning
+                        "num_beams": 1  # Explicitly set to 1 to avoid warnings
                     }
 
                     # Use efficient attention if available
                     if hasattr(self.model.config, "attn_implementation"):
                         generate_params["attn_implementation"] = "flash_attention_2"
 
-                    # Generate tokens
-                    outputs = self.model.generate(**generate_params)
-
-                    # Get the new tokens (skip the input tokens)
-                    new_tokens = outputs[0][len(input_ids[0]):]
-
-                    # Decode and yield each new token
-                    new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-
-                    # If no new text was generated or it's just whitespace, stop generation
-                    if not new_text or new_text.isspace():
-                        break
-
-                    # Yield the new text
-                    yield new_text
-
-                    # Update input_ids and attention_mask for next iteration
-                    input_ids = outputs
-                    attention_mask = torch.ones_like(input_ids)
+                    try:
+                        # Generate tokens
+                        outputs = self.model.generate(**generate_params)
+
+                        # Get the new tokens (skip the input tokens)
+                        new_tokens = outputs[0][len(input_ids[0]):]
+
+                        # Decode and yield each new token
+                        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+                        # If no new text was generated or it's just whitespace, stop generation
+                        if not new_text or new_text.isspace():
+                            break
+
+                        # Yield the new text
+                        yield new_text
+
+                        # Update input_ids and attention_mask for next iteration
+                        input_ids = outputs
+                        attention_mask = torch.ones_like(input_ids)
+
+                        # Ensure the updated inputs are on the correct device
+                        if input_ids.device != model_device:
+                            input_ids = input_ids.to(model_device)
+                        if attention_mask.device != model_device:
+                            attention_mask = attention_mask.to(model_device)
+
+                        # Check if we're running out of memory and need to clear cache
+                        if torch.cuda.is_available():
+                            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+                            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+                            if current_mem > 0.9 * total_mem:  # If using >90% of GPU memory
+                                # Clear cache to avoid OOM
+                                torch.cuda.empty_cache()
+                                logger.info("Cleared CUDA cache to avoid out of memory error")
 
-                    # Ensure the updated inputs are on the correct device
-                    if input_ids.device != model_device:
-                        input_ids = input_ids.to(model_device)
-                    if attention_mask.device != model_device:
-                        attention_mask = attention_mask.to(model_device)
+                    except RuntimeError as e:
+                        if "CUDA out of memory" in str(e):
+                            # If we run out of memory, clear cache and try again with smaller batch
+                            torch.cuda.empty_cache()
+                            logger.warning("CUDA out of memory during streaming. Cleared cache and reducing batch size.")
+
+                            # Reduce tokens per step for the rest of generation
+                            tokens_to_generate_per_step = 1
+                            current_tokens_to_generate = 1
+
+                            # Try again with smaller batch
+                            generate_params["max_new_tokens"] = 1
+                            outputs = self.model.generate(**generate_params)
+
+                            # Continue as before
+                            new_tokens = outputs[0][len(input_ids[0]):]
+                            new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+                            if not new_text or new_text.isspace():
+                                break
+
+                            yield new_text
+
+                            input_ids = outputs
+                            attention_mask = torch.ones_like(input_ids)
+                        else:
+                            # For other errors, re-raise
+                            raise
 
         except Exception as e:
             logger.error(f"Streaming generation failed: {str(e)}")
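Both the batch and streaming paths now share the same "clear the cache when allocation crosses a threshold" check. A minimal sketch of that check factored into a helper (the name and default threshold are assumptions; the 0.8 and 0.9 fractions come from the diff):

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def clear_cache_if_needed(threshold=0.9):
        """Empty the CUDA cache when allocated memory exceeds the given fraction of device memory."""
        if not torch.cuda.is_available():
            return
        allocated = torch.cuda.memory_allocated()
        total = torch.cuda.get_device_properties(0).total_memory
        if allocated > threshold * total:
            torch.cuda.empty_cache()
            logger.info("Cleared CUDA cache to avoid out of memory error")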