     patch_compiled_autograd,
     process_vision_info,
     unsloth_compile_transformers,
+    fast_inference_setup,
 )

 global FORCE_FLOAT32
@@ -142,6 +143,15 @@ def from_pretrained(
         return_logits = False, # Return logits
         fullgraph = True, # No graph breaks
         use_exact_model_name = use_exact_model_name,
+
+        # Pass vLLM/inference parameters
+        fast_inference = fast_inference,
+        gpu_memory_utilization = gpu_memory_utilization,
+        float8_kv_cache = float8_kv_cache,
+        random_state = random_state,
+        max_lora_rank = max_lora_rank,
+        disable_log_stats = disable_log_stats,
+
         qat_scheme = qat_scheme,
         *args, **kwargs,
     )
@@ -370,6 +380,15 @@ def from_pretrained(
         return_logits = False, # Return logits
         fullgraph = True, # No graph breaks
         use_exact_model_name = use_exact_model_name,
+
+        # Pass vLLM/inference parameters
+        fast_inference = fast_inference,
+        gpu_memory_utilization = gpu_memory_utilization,
+        float8_kv_cache = float8_kv_cache,
+        random_state = random_state,
+        max_lora_rank = max_lora_rank,
+        disable_log_stats = disable_log_stats,
+
         *args, **kwargs,
     )
     pass
@@ -388,26 +407,7 @@ def from_pretrained(
     pass

     if fast_inference:
-        if not is_vLLM_available():
-            print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
-            fast_inference = False
-        pass
-        from unsloth_zoo.vllm_utils import (
-            patch_vllm,
-            vllm_dynamic_quant_supported,
-        )
-        patch_vllm()
-        if model_name.endswith("unsloth-bnb-4bit"):
-            if not vllm_dynamic_quant_supported(model_name, model_config):
-                # Instead use -bnb-4bit variant
-                print(
-                    f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n" \
-                    f"we do not yet support fast inference for {model_name}"
-                )
-                model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
-            pass
-        pass
-    pass
+        fast_inference, model_name = fast_inference_setup(model_name, model_config)

     model, tokenizer = dispatch_model.from_pretrained(
         model_name = model_name,
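The block deleted above becomes the body of the new fast_inference_setup helper. A minimal sketch of that helper, reconstructed from the removed lines; its actual module and signature are not shown in the diff, and the is_vLLM_available stand-in below is an assumption since its import never appears here:

import importlib.util

def is_vLLM_available():
    # Stand-in for the project's own availability check (import not shown
    # in the diff); simply tests whether the vllm package is installed.
    return importlib.util.find_spec("vllm") is not None

def fast_inference_setup(model_name, model_config):
    # Returns (fast_inference, model_name): disables fast inference when
    # vLLM is missing, and may swap the model to a supported quant variant.
    if not is_vLLM_available():
        print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
        return False, model_name
    from unsloth_zoo.vllm_utils import (
        patch_vllm,
        vllm_dynamic_quant_supported,
    )
    patch_vllm()
    # Unsloth dynamic-quant checkpoints that vLLM cannot yet serve fall
    # back to the plain -bnb-4bit variant.
    if model_name.endswith("unsloth-bnb-4bit"):
        if not vllm_dynamic_quant_supported(model_name, model_config):
            print(
                f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"
                f"we do not yet support fast inference for {model_name}"
            )
            model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
    return True, model_name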
@@ -530,6 +530,15 @@ def from_pretrained(
     whisper_language = None,
     whisper_task = None,
     unsloth_force_compile = False,
+
+    # Add the missing vLLM/inference parameters
+    fast_inference = False, # uses vLLM
+    gpu_memory_utilization = 0.5,
+    float8_kv_cache = False,
+    random_state = 3407,
+    max_lora_rank = 64,
+    disable_log_stats = True,
+
     qat_scheme = None,
     *args, **kwargs,
 ):
@@ -884,6 +893,15 @@ def from_pretrained(
         supports_sdpa = supports_sdpa,
         whisper_language = whisper_language,
         whisper_task = whisper_task,
+
+        # Pass vLLM/inference parameters
+        fast_inference = fast_inference,
+        gpu_memory_utilization = gpu_memory_utilization,
+        float8_kv_cache = float8_kv_cache,
+        random_state = random_state,
+        max_lora_rank = max_lora_rank,
+        disable_log_stats = disable_log_stats,
+
         *args, **kwargs,
     )
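With the parameters threaded through both from_pretrained paths, enabling vLLM-backed generation would look roughly like the sketch below; the model name is a placeholder and the values simply mirror the defaults added in the signature above:

from unsloth import FastLanguageModel

# Illustrative call using the newly threaded-through vLLM parameters;
# defaults mirror the diff above, the checkpoint name is an example.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",
    fast_inference = True,         # route generation through vLLM
    gpu_memory_utilization = 0.5,  # fraction of VRAM handed to vLLM
    float8_kv_cache = False,       # keep the KV cache in 16-bit
    random_state = 3407,
    max_lora_rank = 64,
    disable_log_stats = True,
)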