Commit a6176ad

Datta0, shimmyshimmer, danielhanchen, jeromeku, and mmathew23 authored
Fast Inference with vLLM for VLMs (#2975)
* [WIP] use vLLM for vision language models
* Update README.md: fix icon sizes (#2885)
* MoE kernels AGPLv3; versioning
* Many bug fixes (#2908): add DeepSeek V3, R1, R1 Zero, and R1 distill models to the model registry; add Mistral Small; refactor model registration APIs (global register, model search, quant-type tests, registry README); Llama 4; synthetic data generation (synthetic.py); Xet support; chat template updates; force float16/float32 for Sesame and fix Sesame; is_multimodal and automatic vision detection; UNSLOTH_DISABLE_STATIC_GENERATION; Whisper; generic efficient GRPO plus logits/temperature work in rl.py and rl_replacements.py; gradient checkpointing; Gemma 3N fixes; setup.py and pyproject.toml updates (co-authored by jeromeku and Michael Han)
* Silently skip the Falcon H1 import if transformers_version < 4.53.0 (#2912)
* Dynamically adjust the get_per_token_logps function and patch it as well (#2911)
* Add Intel GPU support with vLLM (#2903)
* Fixes for the causal mask (#2868): use a non-causal mask in SDPA, add a missing mask, fix its dtype
* Explicitly check whether xformers exists for attention (#2889)
* If mlp doesn't exist in a layer module, check for the feed_forward name for Falcon H1 (#2913)
* Move inputs to the right devices (#2919): multi-GPU RoPE (including Gemma 2), device-count handling, finish multi-GPU inference
* WIP VLM vLLM: make the vLLM patch a function; add save and load LoRA functions; make fast_inference setup depend on the flag; improve the fast-inference patching mechanism; check LoRA and vLLM compatibility for vision models; improve LoRA validation on vLLM; error out when vLLM is missing and increase the max LoRA rank
* Bug fixes (#3017): compiler stance; skip_guard_eval_unsafe fix; fix `quantization_method`; revert "Revert "Add Qwen2.5-VL-32B-Instruct mapping to fix quantized model me…"" (#2990; reverts commit 204fc46)
* Fix for the causal mask (#3011)
* Add an Intel path for llama.py (#3012, co-authored by Daniel Han)
* Fix Gemma 2 (#3024)
* Falcon: force float32 on sm<75 machines (#3026)
* Fix torch compile issues (#3028): check stride; cleanup; rope_embedding.py and gemma2.py fixes; fix `set_stance`; fix up the vLLM patch; disable mllama; use variables to decide VLM support; better attn_impl handling; patch the TF protobuf incompatibility
* Torch 2.8 (#3186): fix Mamba; filter vLLM standby logs (#3131); add a scaler; GPT OSS fixes; upcast norms and layernorms; versioning (co-authored by Datta Nimmaturi)
* Protobuf issue; fix the extras transformers typo in pyproject.toml
* Bug fixes (#3195): UNSLOTH_ENABLE_CCE; import fixes; fix the aimv2 issue; loader upgrades
* Allow float32 dtype in FastLanguageModel (#3204)
* Suppress a log message and prefer Unsloth sampling params (keep TRL sampling params for now); improve the error message; fix up the quantized fast-inference model name

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: DoubleMathew <mmathew23@gmail.com>
Co-authored-by: Lei Zhenyuan <zhenyuan.lei@intel.com>
Co-authored-by: parth2510 <parthguptapg7326@gmail.com>
1 parent 3808d80 · commit a6176ad

File tree

5 files changed: +222 −63 lines changed

unsloth/models/_utils.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -69,6 +69,9 @@
     "patch_fast_lora",
     "validate_loftq_config",
     "RaiseUninitialized",
+    "fast_inference_setup",
+    "patch_peft_fast_inference",
+    "error_out_no_vllm",
     "dequantize_module_weight",
 ]
 
```
```diff
@@ -191,6 +194,12 @@ def filter(self, x): return not (self.text in x.getMessage())
         del vllm_block_pool_logger
     except:
         pass
+    try:
+        from vllm.lora.models import logger as vllm_lora_model_logger
+        vllm_lora_model_logger.addFilter(HideLoggingMessage("Regarding multimodal models, vLLM currently only supports adding"))
+        del vllm_lora_model_logger
+    except:
+        pass
     pass

     # The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.
```
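For reference, `HideLoggingMessage` (whose `filter` method is visible in the hunk header above) follows the standard `logging.Filter` pattern. A minimal self-contained sketch of what the new filter does — the class body here is a reimplementation for illustration, not copied from `_utils.py`:

```python
import logging

# Sketch of the HideLoggingMessage pattern: a logging.Filter that drops any
# record whose message contains a given substring. The `filter` body matches
# the one in the hunk header: `return not (self.text in x.getMessage())`.
class HideLoggingMessage(logging.Filter):
    def __init__(self, text):
        super().__init__()
        self.text = text
    def filter(self, record):
        return not (self.text in record.getMessage())

# Usage mirroring the new hunk: silence vLLM's multimodal-LoRA support warning.
try:
    from vllm.lora.models import logger as vllm_lora_model_logger
    vllm_lora_model_logger.addFilter(
        HideLoggingMessage("Regarding multimodal models, vLLM currently only supports adding")
    )
except Exception:
    pass  # vLLM not installed; nothing to filter
```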
```diff
@@ -1584,6 +1593,45 @@ def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, m

     return loftq_config

+def fast_inference_setup(model_name, model_config):
+    fast_inference = True
+    if not is_vLLM_available():
+        logger.warning_once("Unsloth: vLLM is not installed! Will use Unsloth inference!")
+        fast_inference = False
+    pass
+    from unsloth_zoo.vllm_utils import (
+        patch_vllm,
+        vllm_dynamic_quant_supported,
+    )
+    patch_vllm()
+    if model_name.endswith("unsloth-bnb-4bit"):
+        if not vllm_dynamic_quant_supported(model_name, model_config):
+            # Instead use -bnb-4bit variant
+            logger.warning_once(
+                f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\
+                f"we do not yet support fast inference for {model_name}"
+            )
+            model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
+        pass
+    pass
+    return fast_inference, model_name
+
+def patch_peft_fast_inference(model):
+    vllm_engine = getattr(model.model, "vllm_engine", None)
+    if vllm_engine is not None:
+        model.vllm_engine = model.model.vllm_engine
+        model.fast_generate = model.model.fast_generate
+        model.fast_generate_batches = model.model.fast_generate_batches
+
+        # Also saving and loading LoRA
+        from unsloth_zoo.vllm_utils import save_lora, load_lora
+        model.save_lora = functools.partial(save_lora, model)
+        model.load_lora = functools.partial(load_lora, model)
+    pass
+
+def error_out_no_vllm(*args, **kwargs):
+    raise NotImplementedError("Unsloth: vLLM is not yet supported for fast inference for this model! Please use `.generate` instead")
+

 def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.Module:
     """
```

unsloth/models/llama.py

Lines changed: 3 additions & 24 deletions

```diff
@@ -2574,7 +2574,7 @@ def get_peft_model(
         raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.")
     pass

-    #d oes not get lora yet, so get name from model, not base model
+    # Does not get lora yet, so get name from model, not base model
     is_classification = "Classification" in str(type(model))

     arguments = dict(
@@ -2694,17 +2694,7 @@ def get_peft_model(
         clean_gpu_cache()
     pass

-    # Patch for fast inference
-    if vllm_engine is not None:
-        model.vllm_engine = vllm_engine
-        model.fast_generate = vllm_fast_generate
-        model.fast_generate_batches = vllm_fast_generate_batches
-
-        # Also saving and loading LoRA
-        from unsloth_zoo.vllm_utils import save_lora, load_lora
-        model.save_lora = functools.partial(save_lora, model)
-        model.load_lora = functools.partial(load_lora, model)
-    pass
+    patch_peft_fast_inference(model)

     # Add for_inference and for_training
     model.for_training = functools.partial(FastLlamaModel.for_training, model)
@@ -2916,18 +2906,7 @@ def patch_peft_model(
         clean_gpu_cache()
     pass

-    # Patch for fast inference
-    vllm_engine = getattr(model.model, "vllm_engine", None)
-    if vllm_engine is not None:
-        model.vllm_engine = model.model.vllm_engine
-        model.fast_generate = model.model.fast_generate
-        model.fast_generate_batches = model.model.fast_generate_batches
-
-        # Also saving and loading LoRA
-        from unsloth_zoo.vllm_utils import save_lora, load_lora
-        model.save_lora = functools.partial(save_lora, model)
-        model.load_lora = functools.partial(load_lora, model)
-    pass
+    patch_peft_fast_inference(model)

     # Add for_inference and for_training
     model.for_training = functools.partial(FastLlamaModel.for_training, model)
```
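Both `get_peft_model` and `patch_peft_model` now delegate to the shared `patch_peft_fast_inference` helper, so the user-facing surface is unchanged. Roughly, the flow this preserves looks like the following sketch (checkpoint name, LoRA rank, and prompt are placeholders, not taken from this commit):

```python
# Illustrative sketch of the end-user flow these patches keep working.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",  # hypothetical checkpoint
    fast_inference = True,   # route generation through vLLM
    max_lora_rank = 64,
)
model = FastLanguageModel.get_peft_model(model, r = 16)

# Re-exposed on the PEFT wrapper by patch_peft_fast_inference:
outputs = model.fast_generate(["Hello!"])
model.save_lora("saved_lora")  # unsloth_zoo.vllm_utils.save_lora under the hood
```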

unsloth/models/loader.py

Lines changed: 38 additions & 20 deletions

```diff
@@ -78,6 +78,7 @@
     patch_compiled_autograd,
     process_vision_info,
     unsloth_compile_transformers,
+    fast_inference_setup,
 )

 global FORCE_FLOAT32
@@ -142,6 +143,15 @@ def from_pretrained(
         return_logits = False, # Return logits
         fullgraph = True, # No graph breaks
         use_exact_model_name = use_exact_model_name,
+
+        # Pass vLLM/inference parameters
+        fast_inference = fast_inference,
+        gpu_memory_utilization = gpu_memory_utilization,
+        float8_kv_cache = float8_kv_cache,
+        random_state = random_state,
+        max_lora_rank = max_lora_rank,
+        disable_log_stats = disable_log_stats,
+
         qat_scheme = qat_scheme,
         *args, **kwargs,
     )
@@ -370,6 +380,15 @@ def from_pretrained(
         return_logits = False, # Return logits
         fullgraph = True, # No graph breaks
         use_exact_model_name = use_exact_model_name,
+
+        # Pass vLLM/inference parameters
+        fast_inference = fast_inference,
+        gpu_memory_utilization = gpu_memory_utilization,
+        float8_kv_cache = float8_kv_cache,
+        random_state = random_state,
+        max_lora_rank = max_lora_rank,
+        disable_log_stats = disable_log_stats,
+
         *args, **kwargs,
     )
     pass
@@ -388,26 +407,7 @@ def from_pretrained(
     pass

     if fast_inference:
-        if not is_vLLM_available():
-            print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
-            fast_inference = False
-        pass
-        from unsloth_zoo.vllm_utils import (
-            patch_vllm,
-            vllm_dynamic_quant_supported,
-        )
-        patch_vllm()
-        if model_name.endswith("unsloth-bnb-4bit"):
-            if not vllm_dynamic_quant_supported(model_name, model_config):
-                # Instead use -bnb-4bit variant
-                print(
-                    f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\
-                    f"we do not yet support fast inference for {model_name}"
-                )
-                model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
-            pass
-        pass
-    pass
+        fast_inference, model_name = fast_inference_setup(model_name, model_config)

     model, tokenizer = dispatch_model.from_pretrained(
         model_name = model_name,
@@ -530,6 +530,15 @@ def from_pretrained(
         whisper_language = None,
         whisper_task = None,
         unsloth_force_compile = False,
+
+        # Add the missing vLLM/inference parameters
+        fast_inference = False, # uses vLLM
+        gpu_memory_utilization = 0.5,
+        float8_kv_cache = False,
+        random_state = 3407,
+        max_lora_rank = 64,
+        disable_log_stats = True,
+
         qat_scheme = None,
         *args, **kwargs,
     ):
@@ -884,6 +893,15 @@ def from_pretrained(
         supports_sdpa = supports_sdpa,
         whisper_language = whisper_language,
         whisper_task = whisper_task,
+
+        # Pass vLLM/inference parameters
+        fast_inference = fast_inference,
+        gpu_memory_utilization = gpu_memory_utilization,
+        float8_kv_cache = float8_kv_cache,
+        random_state = random_state,
+        max_lora_rank = max_lora_rank,
+        disable_log_stats = disable_log_stats,
+
         *args, **kwargs,
     )
```
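With these defaults in place, the vision loader accepts the same vLLM knobs as the text loader. A sketch of a call using them — the checkpoint name is illustrative, and whether a given VLM works still depends on the LoRA/vLLM compatibility checks mentioned in the commit message:

```python
# Sketch: defaults shown match the new signature in the hunk above.
from unsloth import FastVisionModel

model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2.5-VL-7B-Instruct",  # hypothetical checkpoint
    fast_inference = True,             # uses vLLM
    gpu_memory_utilization = 0.5,      # fraction of VRAM vLLM may claim
    float8_kv_cache = False,
    max_lora_rank = 64,
    disable_log_stats = True,
)
```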

unsloth/models/rl.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -550,7 +550,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     pass

     # Warn on too large or too small learning rate
-    if " learning_rate" in call_args:
+    if "learning_rate" in call_args:
         learning_rate_check = \
             "if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\
             "Consider increasing it, otherwise gradient updates will be close to 0!')\n"\
@@ -937,6 +937,13 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
             r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))",
             source
         )
+        # Prefer using unsloth's sampling params and fallback to trl's if not found
+        # We'll enable this later separately when combining both this and GRPOConfig params
+        # source = re.sub(
+        #     r"sampling_params\s*=\s*sampling_params",
+        #     r"sampling_params = getattr(self.args, 'vllm_sampling_params', sampling_params)",
+        #     source
+        # )

         # Skip if no changes done
         if source == original_source: continue
```
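The commented-out substitution would rewrite the patched trainer source so that a user-supplied `vllm_sampling_params` on the args object wins over TRL's value. A tiny standalone demo of the `getattr` fallback it would inject (all names here are stand-ins):

```python
from types import SimpleNamespace

# Stand-ins: `trl_sampling_params` plays TRL's default; `args` plays the
# GRPOConfig-style object the patched trainer would read from.
trl_sampling_params = {"temperature": 0.8}
args = SimpleNamespace(vllm_sampling_params = {"temperature": 0.2})

# The line the regex would inject: prefer Unsloth's params, else TRL's.
sampling_params = getattr(args, "vllm_sampling_params", trl_sampling_params)
print(sampling_params)  # {'temperature': 0.2}; falls back when the attr is absent
```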
