diff --git a/examples/alora/alora_offline_example.py b/examples/alora/alora_offline_example.py
new file mode 100644
index 00000000000..4133ac341d2
--- /dev/null
+++ b/examples/alora/alora_offline_example.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+import time
+
+import torch
+from huggingface_hub import snapshot_download
+
+from vllm import LLM, SamplingParams
+from vllm.lora.request import LoRARequest
+
+BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"
+
+ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
+invocation_string = "<|start_of_role|>certainty<|end_of_role|>"
+
+os.environ["VLLM_USE_V1"] = "1"
+
+# download your LoRA adapter to ~/.cache/huggingface/…
+alora_path = snapshot_download(repo_id=ALORA_NAME)
+
+print(alora_path)
+#######################################
+
+
+llm = LLM(
+    model=BASE_NAME,
+    enable_lora=True,
+    enable_activated_lora=True,
+    dtype=torch.bfloat16,
+    max_lora_rank=64,
+)
+
+prompts = [
+    (
+        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
+        "<|start_of_role|>assistant<|end_of_role|>"
+    ),
+]
+
+sampling_params = SamplingParams(temperature=0, max_tokens=600)
+
+outputs_base = llm.generate(
+    prompts,
+    sampling_params,
+    use_tqdm=False,
+)
+generated_text = []
+for output in outputs_base:
+    prompt = output.prompt
+    generated_text += [output.outputs[0].text]
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")
+
+prompts_alora = [
+    x + y + "<|end_of_text|>\n" + invocation_string
+    for x, y in zip(prompts, generated_text)
+]
+
+sampling_params = SamplingParams(temperature=0, max_tokens=10)
+
+t0 = time.time()
+outputs = llm.generate(
+    prompts_alora,
+    sampling_params,
+    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
+    use_tqdm=False,
+)
+t = time.time() - t0
+print(f"Time: {t}")
+
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/vllm/config.py b/vllm/config.py
index 54c7a497b26..19bb81fd0f8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2876,6 +2876,8 @@ class LoRAConfig:
     allowed."""
     bias_enabled: bool = False
     """Enable bias for LoRA adapters."""
+    activated_lora_enabled: bool = False
+    """Enable Activated LoRA."""
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7a88e3269a5..82f33acc4fe 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -355,6 +355,7 @@ class EngineArgs:
     # LoRA fields
     enable_lora: bool = False
     enable_lora_bias: bool = LoRAConfig.bias_enabled
+    enable_activated_lora: bool = LoRAConfig.activated_lora_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
     fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
@@ -733,6 +734,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
                                **lora_kwargs["bias_enabled"])
+        lora_group.add_argument("--enable-activated-lora",
+                                **lora_kwargs["activated_lora_enabled"])
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
                                **lora_kwargs["max_lora_rank"])
@@ -1191,6 +1194,7 @@ def create_engine_config(
 
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
+            activated_lora_enabled=self.enable_activated_lora,
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
             fully_sharded_loras=self.fully_sharded_loras,
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index dd55b19feea..a9870dae6e6 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -26,6 +26,11 @@
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+@dataclass
+class ALoRAMetadata:
+    mask1d: torch.Tensor
+
+
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
@@ -94,6 +99,7 @@ class ForwardContext:
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    alora_metadata: Optional[ALoRAMetadata] = None
     skip_cuda_graphs: bool = False
 
 
@@ -116,6 +122,7 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    skip_cuda_graphs: bool = False,
+    alora_metadata: Optional[ALoRAMetadata] = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -140,6 +147,7 @@ def set_forward_context(
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
+        alora_metadata=alora_metadata,
         skip_cuda_graphs=skip_cuda_graphs,
     )
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 3d0c5831750..145d3bbcc5b 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -19,6 +19,7 @@
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
 from vllm.distributed.utils import divide
+from vllm.forward_context import get_forward_context
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -874,7 +875,8 @@ def can_replace_layer(
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return (type(source_layer) is QKVParallelLinear
-                and len(packed_modules_list) == 3)
+                and len(packed_modules_list) == 3
+                and not lora_config.activated_lora_enabled)
 
 
 #TODO: Implement this
@@ -1283,3 +1285,50 @@ def extra_repr(self) -> str:
 
     def extra_repr(self) -> str:
         return self.base_layer.extra_repr()
+
+
+class MergedQKVParallelLinearWithActivatedLoRA(MergedQKVParallelLinearWithLoRA
+                                               ):
+
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
+        # Extract aLoRA batch metadata from forward context
+        alora_metadata = get_forward_context().alora_metadata
+
+        mask1d = alora_metadata.mask1d
+        mask2d = mask1d.unsqueeze(1).to(output.dtype)
+
+        # Clone base layer output before running LoRA
+        orig_out = output.clone()
+
+        # Apply LoRA in-place on `output`:
+        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
+                                            self.lora_b_stacked,
+                                            self.lora_bias_stacked, 1.0,
+                                            self.output_slices)
+        # Apply aLoRA mask: base output before invocation start, LoRA after
+        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
+        return final_output
+
+    @classmethod
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: Optional[PretrainedConfig],
+    ) -> bool:
+        """Returns True if the layer can be replaced by this LoRA layer."""
+        return (type(source_layer) is QKVParallelLinear
+                and len(packed_modules_list) == 3
+                and lora_config.activated_lora_enabled)
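
For readers tracing the masking logic, here is a minimal, self-contained sketch of the blending rule implemented by MergedQKVParallelLinearWithActivatedLoRA.apply above. It is illustrative only and not part of the patch; the tensor shapes and the stand-in "LoRA" output are assumptions made for the demo.

import torch

seq_len, hidden = 6, 4
invocation_start = 4  # aLoRA weights activate from this token position on

positions = torch.arange(seq_len)
mask1d = positions < invocation_start            # True -> keep base output
mask2d = mask1d.unsqueeze(1).to(torch.float32)   # (seq_len, 1), as in apply()

base_out = torch.randn(seq_len, hidden)                    # base_layer output
lora_out = base_out + 0.1 * torch.randn(seq_len, hidden)   # stand-in for base + LoRA delta

# Same blend as `orig_out.mul(mask2d) + output.mul(1.0 - mask2d)` above:
# tokens before the invocation point keep the base output, later tokens
# get the LoRA-adapted output.
final = base_out * mask2d + lora_out * (1.0 - mask2d)
assert torch.equal(final[:invocation_start], base_out[:invocation_start])
assert torch.equal(final[invocation_start:], lora_out[invocation_start:])
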
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index a20d73f0f72..edd7be16fdf 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -37,6 +37,8 @@ class PEFTHelper:
     use_dora: bool = field(default=False)
     # long context lora field
     context_length: int = field(default=0)
+    # Invocation string for Activated LoRA (aLoRA, see: https://arxiv.org/abs/2504.12397)
+    invocation_string: Optional[str] = field(default=None)
     # Extra vllm field, start with 'vllm_' to avoid conflict
     vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
diff --git a/vllm/lora/request.py b/vllm/lora/request.py
index 5bbba7830c1..64762f15ec2 100644
--- a/vllm/lora/request.py
+++ b/vllm/lora/request.py
@@ -33,6 +33,7 @@ class LoRARequest(
     long_lora_max_len: Optional[int] = None
     base_model_name: Optional[str] = msgspec.field(default=None)
     tensorizer_config_dict: Optional[dict] = None
+    invocation_start: Optional[int] = None
 
     def __post_init__(self):
         if self.lora_local_path:
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index ee196e3f689..f5ad741fdcc 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -25,6 +25,7 @@
                              LinearScalingRotaryEmbeddingWithLoRA,
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
+                             MergedQKVParallelLinearWithActivatedLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
@@ -44,6 +45,7 @@
     MergedColumnParallelLinearWithLoRA,
     QKVParallelLinearWithLoRA,
     MergedQKVParallelLinearWithLoRA,
+    MergedQKVParallelLinearWithActivatedLoRA,
     RowParallelLinearWithLoRA,
     ReplicatedLinearWithLoRA,
     LogitsProcessorWithLoRA,
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 588aa8deb18..0648e8e6a19 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -229,6 +230,16 @@ def __init__(
     ):
         super().__init__()
 
+        vllm_config = get_current_vllm_config()
+        if (vllm_config.lora_config
+                and vllm_config.lora_config.activated_lora_enabled):
+            # lets torch.compile know that forward_context needs to be
+            # considered as an input to the layer (copied from attention)
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 9489bcf433f..b5a0af9cb76 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -473,6 +473,12 @@ def hash_request_tokens(hash_function: Any, block_size: int,
         # MM and LoRA requests need extra keys for block-hash computation.
         req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
             request, start, end, curr_mm_idx)
+        # Respect aLoRA behaviour
+        if (request.lora_request is not None
+                and request.lora_request.invocation_start is not None
+                and end <= request.lora_request.invocation_start):
+            # blocks before the invocation start match the base model cache
+            req_extra_keys = None
 
         block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
                                        block_token_ids, req_extra_keys)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index b00f1444c7b..c783ced22bd 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -9,6 +9,7 @@
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
@@ -328,6 +329,44 @@ def process_inputs(
         else:
             sorted_mm_inputs = orig_sorted_mm_inputs
 
+        # Tokenize the aLoRA invocation sequence if applicable.
+        if (self.lora_config and self.lora_config.activated_lora_enabled
+                and lora_request is not None):
+
+            text_config = self.model_config.hf_config.get_text_config()
+
+            peft_helper = PEFTHelper.from_local_dir(
+                lora_request.lora_path, text_config.max_position_embeddings,
+                lora_request.tensorizer_config_dict)
+
+            if peft_helper.invocation_string is not None:
+
+                invocation_tokens = self.input_preprocessor._tokenize_prompt(
+                    peft_helper.invocation_string,
+                    lora_request=lora_request,
+                    tokenization_kwargs=tokenization_kwargs)
+
+                invocation_start = -1
+                n = len(invocation_tokens)
+                token_ids = decoder_inputs["prompt_token_ids"]
+
+                if n > 0 and len(token_ids) >= n:
+                    # scan backward for the last match
+                    # (faster than full forward scan+max)
+                    for idx in range(len(token_ids) - n, -1, -1):
+                        if token_ids[idx:idx + n] == invocation_tokens:
+                            # weights activate 1 token after the start
+                            invocation_start = idx + 1
+                            break
+
+                if invocation_start == -1:
+                    raise ValueError(
+                        "Invocation sequence not found in prompt "
+                        f"for request '{request_id}'. aLoRA models require the "
+                        "invocation tokens to be present in the input.")
+
+                lora_request.invocation_start = invocation_start
+
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
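
A minimal sketch of the backward scan that process_inputs uses above to locate the invocation sequence in the prompt; illustrative only and not part of the patch, with made-up token ids.

def find_invocation_start(token_ids: list[int],
                          invocation_tokens: list[int]) -> int:
    """Return 1 + the index of the last occurrence of invocation_tokens in
    token_ids, or -1 if it is not present (mirrors the logic above)."""
    n = len(invocation_tokens)
    if n == 0 or len(token_ids) < n:
        return -1
    for idx in range(len(token_ids) - n, -1, -1):
        if token_ids[idx:idx + n] == invocation_tokens:
            return idx + 1  # weights activate one token after the match start
    return -1

# e.g. a prompt ending with the invocation sequence [7, 8, 9]:
assert find_invocation_start([1, 2, 3, 7, 8, 9], [7, 8, 9]) == 4
assert find_invocation_start([1, 2, 3], [7, 8, 9]) == -1
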
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f96fb64342c..917931eacbe 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -27,8 +27,8 @@
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import (DPMetadata, get_forward_context,
-                                  set_forward_context)
+from vllm.forward_context import (ALoRAMetadata, DPMetadata,
+                                  get_forward_context, set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -238,6 +238,11 @@ def __init__(
                                            dtype=torch.int64,
                                            device=self.device)
 
+        if self.lora_config and self.lora_config.activated_lora_enabled:
+            self.mask1d = torch.zeros(self.max_num_tokens,
+                                      dtype=torch.int64,
+                                      device=self.device)
+
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
@@ -568,8 +573,9 @@ def _get_cumsum_and_arange(
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], bool, torch.Tensor,
-               Optional[SpecDecodeMetadata], np.ndarray]:
+    ) -> tuple[dict[str,
+                    Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
+               Optional[ALoRAMetadata], np.ndarray]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
@@ -755,8 +761,31 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
+        # Compute aLoRA metadata
+        if self.lora_config and self.lora_config.activated_lora_enabled:
+            invocation_start = np.empty(shape=(num_reqs, ), dtype=int)
+            for req_id in self.input_batch.req_ids:
+                req_index = self.input_batch.req_id_to_index[req_id]
+                cached_lora_request = self.requests[req_id].lora_request
+                if (cached_lora_request is not None
+                        and cached_lora_request.invocation_start is not None):
+                    invocation_start[
+                        req_index] = cached_lora_request.invocation_start
+                else:
+                    invocation_start[req_index] = len(
+                        self.requests[req_id].prompt_token_ids)
+            mask1d_cpu = torch.tensor(positions_np
+                                      < invocation_start[req_indices],
+                                      dtype=torch.bool,
+                                      device="cpu")
+            mask1d = self.mask1d[:total_num_scheduled_tokens]
+            mask1d.copy_(mask1d_cpu, non_blocking=True)
+            alora_metadata = ALoRAMetadata(mask1d=mask1d)
+        else:
+            alora_metadata = None
+
         return (attn_metadata, attention_cuda_graphs, logits_indices,
-                spec_decode_metadata, num_scheduled_tokens)
+                spec_decode_metadata, alora_metadata, num_scheduled_tokens)
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1265,7 +1294,7 @@ def execute_model(
         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
-         spec_decode_metadata,
+         spec_decode_metadata, alora_metadata,
          num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
@@ -1344,6 +1373,7 @@ def execute_model(
             num_tokens=num_input_tokens,
             num_tokens_across_dp=num_tokens_across_dp,
             skip_cuda_graphs=skip_cuda_graphs,
+            alora_metadata=alora_metadata,
         ):
             self.maybe_setup_kv_connector(scheduler_output)
 
@@ -1936,11 +1966,19 @@ def _dummy_run(
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
 
+        alora_metadata = None
+        if self.lora_config and self.lora_config.activated_lora_enabled:
+            mask1d = self.mask1d[:num_tokens]
+            alora_metadata = ALoRAMetadata(mask1d=mask1d)
+            # mark the first dim dynamic to avoid torch.compile guard failures
+            torch._dynamo.mark_dynamic(alora_metadata.mask1d, 0)
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
-                num_tokens_across_dp=num_tokens_across_dp):
+                num_tokens_across_dp=num_tokens_across_dp,
+                alora_metadata=alora_metadata):
             outputs = model(
                 input_ids=input_ids,
                 positions=positions,
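
Finally, a small sketch of how _prepare_inputs above flattens per-request invocation starts into the token-level mask1d consumed by the aLoRA layer; illustrative only and not part of the patch, and the batch layout is an assumption for the demo.

import numpy as np

# Two requests scheduled back to back: 4 tokens for request 0, 3 for request 1.
invocation_start = np.array([2, 1])             # per-request activation points
req_indices = np.array([0, 0, 0, 0, 1, 1, 1])   # request index of each token
positions_np = np.array([0, 1, 2, 3, 0, 1, 2])  # position of each token in its request

# True -> token is before its request's invocation start -> base weights only.
mask1d = positions_np < invocation_start[req_indices]
assert mask1d.tolist() == [True, True, False, False, True, False, False]
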