From 8a9610e62f534c1f17afa121122f9fd93aaa01bd Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 16 Jun 2025 20:28:05 +0000 Subject: [PATCH 01/19] Initial working implementation of a-LoRA. Co-authored-by: Greenewald Co-authored-by: Allison Li Signed-off-by: Thomas Parnell --- examples/alora/alora_server_testing.py | 67 ++++++++++++++++ examples/alora/alora_server_testing.sh | 46 +++++++++++ examples/alora/new_alora_testing.py | 74 +++++++++++++++++ vllm/envs.py | 5 ++ vllm/forward_context.py | 9 +++ vllm/lora/layers.py | 47 +++++++++-- vllm/lora/request.py | 2 + vllm/model_executor/layers/linear.py | 7 ++ vllm/v1/core/kv_cache_utils.py | 16 ++++ vllm/v1/core/sched/scheduler.py | 14 ++++ vllm/v1/engine/processor.py | 20 +++++ vllm/v1/worker/gpu_model_runner.py | 105 ++++++++++++++++++++++++- 12 files changed, 401 insertions(+), 11 deletions(-) create mode 100644 examples/alora/alora_server_testing.py create mode 100644 examples/alora/alora_server_testing.sh create mode 100644 examples/alora/new_alora_testing.py diff --git a/examples/alora/alora_server_testing.py b/examples/alora/alora_server_testing.py new file mode 100644 index 00000000000..e9616600a54 --- /dev/null +++ b/examples/alora/alora_server_testing.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# After starting server using "vllm serve --enable_lora --lora_modules..." + +import time + +from openai import OpenAI + +model_id = "ibm-granite/granite-3.2-8b-instruct" + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +BASE_NAME = "ibm-granite/granite-3.2-8b-instruct" +ALORA_NAME = "new_alora" # "ibm-granite/granite-3.2-8b-alora-uncertainty" +invocation_string = "<|start_of_role|>certainty<|end_of_role|>" + +################################################################### +prompts = [ + "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>", + "What is MIT?", + ( + "<|start_of_role|>user<|end_of_role|>What is the capital of " + "Massachusetts?<|end_of_text|>\n" + ), + "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>", + ( + "<|start_of_role|>user<|end_of_role|>What is the capital of " + "Massachusetts?<|end_of_text|>\n" + ), + "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>", +] + +# Base model call +outputs_base = client.completions.create( + model=BASE_NAME, prompt=prompts, temperature=0, max_tokens=600 +) + +choices = outputs_base.choices +generated_text = [] +for i in range(len(prompts)): + prompt = prompts[i] + + generated_text += [outputs_base.choices[i].text] + print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}") + +prompts_alora = [ + x + y + "<|end_of_text|>\n" + invocation_string + for x, y in zip(prompts, generated_text) +] + +# Base model with aLoRA call +t0 = time.time() +alora_outputs = client.completions.create( + model=ALORA_NAME, prompt=prompts_alora, temperature=0, max_tokens=10 +) +t = time.time() - t0 +print(f"Time: {t}") +for i in range(len(prompts_alora)): + prompt = prompts_alora[i] + generated_text = alora_outputs.choices[i].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/alora/alora_server_testing.sh b/examples/alora/alora_server_testing.sh new file mode 100644 index 00000000000..49eb9c5612f --- /dev/null +++ b/examples/alora/alora_server_testing.sh @@ 
-0,0 +1,46 @@ +#!/bin/bash + +# More documentation: https://docs.vllm.ai/en/v0.8.3/serving/openai_compatible_server.html#vllm-serve +export VLLM_USE_V1="1" +# Specify base model (and optionally loras) to load in when starting the server. +vllm serve ibm-granite/granite-3.2-8b-instruct \ + --enable-lora \ + --lora-modules '{"name": "new_alora", "path": "/proj/dmfexp/statllm/users/kgreenewald/.cache/huggingface/models/hub/models--ibm-granite--granite-3.2-8b-alora-uncertainty/snapshots/6109ad88201426003e696d023ec67c19e7f3d444", "base_model_name": "ibm-granite/granite-3.2-8b-instruct"}' \ + --dtype bfloat16 \ + --max-lora-rank 64 \ + --enable-prefix-caching +#--no-enable-prefix-caching +# Check that the lora model is listed along with other models. +#curl localhost:8000/v1/models | jq . + +########################################### + +# A second option is to enable dynamic adapter loading instead of at start-up. +#export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True + +#curl -X POST http://localhost:8000/v1/load_lora_adapter \ +#-H "Content-Type: application/json" \ +#-d '{ +# "lora_name": "new_alora", +# "lora_path": "/path/to/new_alora" +#}' +# Should return "200 OK - Success: LoRA adapter 'new_alora' added successfully" + +# Example of dynamically unloading an adapter. +# curl -X POST http://localhost:8000/v1/unload_lora_adapter \ +# -H "Content-Type: application/json" \ +# -d '{ +# "lora_name": "new_alora" +# }' + +########################################### + +# Send a request using the new aLoRA +#curl http://localhost:8000/v1/completions \ +# -H "Content-Type: application/json" \ +# -d '{ +# "model": "new_alora", +# "prompt": ""What is MIT?"", +# "max_tokens": 600, +# "temperature": 0 +# }' | jq diff --git a/examples/alora/new_alora_testing.py b/examples/alora/new_alora_testing.py new file mode 100644 index 00000000000..5d3908ff8f2 --- /dev/null +++ b/examples/alora/new_alora_testing.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import time + +import torch +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + +BASE_NAME = "ibm-granite/granite-3.2-8b-instruct" +ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty" +invocation_string = "<|start_of_role|>certainty<|end_of_role|>" + +os.environ["VLLM_USE_V1"] = "1" +os.environ["VLLM_V1_USE_DEMO_LOGGING"] = "1" + +# download your LoRA adapter to ~/.cache/huggingface/… +alora_path = snapshot_download(repo_id=ALORA_NAME) + +print(alora_path) +####################################### + + +llm = LLM( + model=BASE_NAME, + enable_lora=True, + enforce_eager=True, + dtype=torch.bfloat16, + enable_prefix_caching=True, # enable APC + max_lora_rank=64, + enable_chunked_prefill=False, +) + +prompts = [ + ( + "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n" + "<|start_of_role|>assistant<|end_of_role|>" + ), +] + +sampling_params = SamplingParams(temperature=0, max_tokens=600) + +outputsBase = llm.generate( + prompts, + sampling_params, +) +generated_text = [] +for output in outputsBase: + prompt = output.prompt + generated_text += [output.outputs[0].text] + print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}") + +prompts_alora = [ + x + y + "<|end_of_text|>\n" + invocation_string + for x, y in zip(prompts, generated_text) +] + +sampling_params = SamplingParams(temperature=0, max_tokens=10) + +t0 = time.time() +outputs = llm.generate( + prompts_alora, + 
sampling_params, + lora_request=LoRARequest("UQ_adapter", 1, alora_path), +) +t = time.time() - t0 +print(f"Time: {t}") + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/envs.py b/vllm/envs.py index 921052821ee..92d985c79fc 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -120,6 +120,7 @@ VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 + VLLM_V1_USE_DEMO_LOGGING: bool = True VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 @@ -835,6 +836,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), + # Useful for demo + "VLLM_V1_USE_DEMO_LOGGING": + lambda: os.environ.get("VLLM_V1_USE_DEMO_LOGGING", "0") == "1", + # If set, allow insecure serialization using pickle. # This is useful for environments where it is deemed safe to use the # insecure method and it is needed for some reason. diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feea..144e3cc4a1d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -26,6 +26,12 @@ batchsize_forward_time: defaultdict = defaultdict(list) +@dataclass +class ALoRAMetadata: + k_offsets: torch.Tensor + query_start_locs: list[int] + + @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor @@ -94,6 +100,7 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None + alora_metadata: Optional[ALoRAMetadata] = None skip_cuda_graphs: bool = False @@ -116,6 +123,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, + alora_metadata: Optional[ALoRAMetadata] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. 
@@ -140,6 +148,7 @@ def set_forward_context( virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, + alora_metadata=alora_metadata, skip_cuda_graphs=skip_cuda_graphs, ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3d0c5831750..db8a881b089 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -19,6 +19,7 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.utils import divide +from vllm.forward_context import get_forward_context # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -418,14 +419,44 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) - if not current_platform.can_update_inplace(): - output = lora_output - - return output + # Extract aLoRA batch metadata from forward context + alora_metadata = get_forward_context().alora_metadata + k_offsets = alora_metadata.k_offsets + query_start_locs = alora_metadata.query_start_locs + + # Build the 1D “save‐prefix” mask: + T = output.size(0) # total tokens + starts = query_start_locs[:-1] # starts and end index of each request + ends = query_start_locs[1:] + lengths = ends - starts # request lengths + kept_lens = lengths - k_offsets + kept_lens = torch.clamp( + kept_lens, + min=0) # portion of request to keep as base model weights + + device = output.device + # Create the alora mask + delta = torch.zeros(T + 1, device=device, dtype=output.dtype) + ends_for_scatter = starts + kept_lens + pos_vals = kept_lens.sign().to(output.dtype) + neg_vals = -pos_vals + delta.scatter_add_(0, starts, pos_vals) + delta.scatter_add_(0, ends_for_scatter, neg_vals) + cums = torch.cumsum(delta[:-1], dim=0) + mask1d = cums > 0 # shape [T], bool + mask2d = mask1d.unsqueeze(1).to(output.dtype) + + # Clone base layer output before running LoRA + orig_out = output.clone() + + # Apply LoRA in‐place on `output`: + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) + # Apply alora mask + final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d) + return final_output @property def weight(self) -> torch.Tensor: diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 5bbba7830c1..a566e50d605 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -33,6 +33,8 @@ class LoRARequest( long_lora_max_len: Optional[int] = None base_model_name: Optional[str] = msgspec.field(default=None) tensorizer_config_dict: Optional[dict] = None + invocation_tokens: Optional[list[int]] = None + k_offset: Optional[int] = None def __post_init__(self): if self.lora_local_path: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 588aa8deb18..f91e05fddcf 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,6 +9,7 @@ import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter +from vllm.config import get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -229,6 +230,12 @@ def __init__( ): super().__init__() + # tpa -- find out why this is needed + compilation_config = get_current_vllm_config().compilation_config + if prefix in 
compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # Keep input parameters self.input_size = input_size self.output_size = output_size diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9489bcf433f..7eff333a124 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -457,6 +457,20 @@ def hash_request_tokens(hash_function: Any, block_size: int, token_ids = request.all_token_ids req_need_extra_keys = need_extra_keys(request) + if (request.lora_request is not None + and request.lora_request.invocation_tokens is not None): + use_alora = True + invocation_tokens = request.lora_request.invocation_tokens + # scan backward for the last match (faster than full forward scan+max) + invocation_start = -1 + n = len(invocation_tokens) + for idx in range(len(token_ids) - n, -1, -1): + if token_ids[idx:idx + n] == invocation_tokens: + # weights activated 1 token after start + invocation_start = idx + 1 + break + else: + use_alora = False req_extra_keys = None curr_mm_idx = 0 @@ -473,6 +487,8 @@ def hash_request_tokens(hash_function: Any, block_size: int, # MM and LoRA requests need extra keys for block-hash computation. req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) + if use_alora and end <= invocation_start: + req_extra_keys = None # cache is equivalent to base model cache block_hash = hash_block_tokens(hash_function, parent_block_hash_value, block_token_ids, req_extra_keys) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2d2274ab6a4..4bb2bbd6e2c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -8,6 +8,7 @@ from collections.abc import Iterable from typing import Any, Optional, Union +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -211,6 +212,13 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, self.max_model_len - request.num_computed_tokens) + if envs.VLLM_V1_USE_DEMO_LOGGING and num_new_tokens > 1: + logger.info("request_id: %s", request.request_id) + logger.info("num_tokens: %d", request.num_tokens) + logger.info("num_computed_tokens: %d", + request.num_computed_tokens) + logger.info("num_new_tokens: %d", num_new_tokens) + # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget @@ -416,6 +424,12 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. 
break + if envs.VLLM_V1_USE_DEMO_LOGGING: + logger.info("request_id: %s", request.request_id) + logger.info("num_tokens: %d", request.num_tokens) + logger.info("num_computed_tokens: %d", num_computed_tokens) + logger.info("num_new_tokens: %d", num_new_tokens) + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index e28879d4046..d6b0c28f16f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os import time from collections.abc import Mapping, Sequence from typing import Any, Literal, Optional, Union @@ -10,6 +12,7 @@ from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.multimodal.inputs import PlaceholderRange @@ -324,6 +327,23 @@ def process_inputs( else: sorted_mm_inputs = orig_sorted_mm_inputs + # Tokenize aLoRA invocation sequence if applicable. + if lora_request is not None: + + # Load in adapter config file + lora_path = get_adapter_absolute_path(lora_request.lora_path) + lora_config_path = os.path.join(lora_path, "adapter_config.json") + with open(lora_config_path) as f: + config = json.load(f) + + if "invocation_string" in config: # check if aLoRA + invocation_tokens = self.input_preprocessor._tokenize_prompt( + config["invocation_string"], + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs) + # Make it an aLoRA request + # (in future, this will happen upstream) + lora_request.invocation_tokens = invocation_tokens return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 558325fa034..c0cdfdc9048 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -26,8 +26,8 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, - set_forward_context) +from vllm.forward_context import (ALoRAMetadata, DPMetadata, + get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader @@ -746,6 +746,83 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata) + def _extract_offsets( + self, + scheduler_output: "SchedulerOutput", + ) -> ALoRAMetadata: + """ + Extract k_offsets for each new scheduled req that is called with aLoRA. + Prepare aLoRA metadata for model execution. 
+ """ + + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + print(new_req_data.lora_request) + if (new_req_data.lora_request is not None and + new_req_data.lora_request.invocation_tokens is not None): + tokens = new_req_data.lora_request.invocation_tokens + prompt_ids = new_req_data.prompt_token_ids + n = len(tokens) + k_offset = -1 + # only bother if there actually are invocation tokens + if n > 0 and len(prompt_ids) >= n: + # scan backward for the last match + # (faster than full forward scan+max) + for idx in range(len(prompt_ids) - n, -1, -1): + if prompt_ids[idx:idx + n] == tokens: + # offset = number of tokens from the start + # of that match to the end of the prompt + k_offset = len(prompt_ids) - idx - 1 + break + if k_offset == -1: + raise ValueError( + "Invocation sequence not found in prompt " + f"for request '{req_id}'. aLoRA models require the " + "invocation tokens to be present in the input.") + + cached_lora_request = self.requests[req_id].lora_request + assert cached_lora_request is not None + cached_lora_request.k_offset = k_offset + + # Fill in k_offsets based on the `scheduled_new_reqs` and + # `scheduled_cached_reqs` within the SchedulerOutput. + num_seqs = len(self.query_start_loc_np.tolist()) - 1 + k_offsets = [1] * (num_seqs) + + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + req_index = self.input_batch.req_id_to_index[req_id] + cached_lora_request = self.requests[req_id].lora_request + if (cached_lora_request is not None + and cached_lora_request.k_offset is not None): + k_offsets[req_index] = cached_lora_request.k_offset + else: + k_offsets[req_index] = len( + self.requests[req_id].prompt_token_ids) + + for cached_req_data in scheduler_output.scheduled_cached_reqs: + req_id = cached_req_data.req_id + req_index = self.input_batch.req_id_to_index[req_id] + cached_lora_request = self.requests[req_id].lora_request + if (cached_lora_request is not None + and cached_lora_request.k_offset is not None): + k_offsets[req_index] = cached_lora_request.k_offset + else: + k_offsets[req_index] = len( + self.requests[req_id].prompt_token_ids) + + query_locs = torch.tensor(self.query_start_loc_np.tolist(), + device=self.device) + + if len(query_locs) > self.input_batch.num_reqs + 1: + query_locs[self.input_batch.num_reqs + 1:] = 0 + + alora_metadata = ALoRAMetadata(k_offsets=torch.tensor( + k_offsets, device=self.device), + query_start_locs=query_locs) + + return alora_metadata + def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, @@ -1209,6 +1286,11 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata) = (self._prepare_inputs(scheduler_output)) + + # tpa - let's do this in prepare input> + # Extract the aLoRA offsets if applicable. 
+ alora_metadata = self._extract_offsets(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1286,6 +1368,7 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_cuda_graphs=skip_cuda_graphs, + alora_metadata=alora_metadata, ): self.maybe_setup_kv_connector(scheduler_output) @@ -1872,11 +1955,27 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) + # Prepare dummy ALoRAMetadata + dummy_k_offsets = torch.tensor([1] * max_num_reqs, + device=self.device) + dummy_cu_num_tokens = np.cumsum(num_scheduled_tokens) + dummy_query_start_loc = [0] * (max_num_reqs + 1) + dummy_query_start_loc[0] = 0 + dummy_query_start_loc[1:num_reqs + 1] = dummy_cu_num_tokens + dummy_query_start_loc = torch.tensor(dummy_query_start_loc, + device=self.device) + dummy_alora_metadata = ALoRAMetadata( + k_offsets=dummy_k_offsets, + query_start_locs=dummy_query_start_loc, + ) + #num_reqs=num_reqs,) + with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp, + alora_metadata=dummy_alora_metadata): outputs = model( input_ids=input_ids, positions=positions, From a68e70b9a109e53bf7b20c0ecac96b04944dc194 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 17 Jun 2025 07:39:42 +0000 Subject: [PATCH 02/19] Fix type hint for query_start_locs Signed-off-by: Thomas Parnell --- vllm/forward_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 144e3cc4a1d..cf8447a5a25 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -29,7 +29,7 @@ @dataclass class ALoRAMetadata: k_offsets: torch.Tensor - query_start_locs: list[int] + query_start_locs: torch.Tensor @dataclass From b254fb7ee12825b70d1e0a60c641e1dd96d65a70 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 04:33:37 +0000 Subject: [PATCH 03/19] vllm/model_executor/layers/linear.py: add comment on torch.compile Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f91e05fddcf..ddf55bc49ad 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -230,7 +230,8 @@ def __init__( ): super().__init__() - # tpa -- find out why this is needed + # lets torch.compile know that forward_context needs to be + # considered as an input to the layer (copied from attention) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") From 3897b1b630e5913dbcbc89571d42412c4a0d9bc2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 04:34:36 +0000 Subject: [PATCH 04/19] vllm/v1/worker/gpu_model_runner.py: remove print statement Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c0cdfdc9048..a9bd1edf3cc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -757,7 +757,6 @@ def _extract_offsets( for 
new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id - print(new_req_data.lora_request) if (new_req_data.lora_request is not None and new_req_data.lora_request.invocation_tokens is not None): tokens = new_req_data.lora_request.invocation_tokens From 24ff3760b40860dbe1193642c8f2b20b3ed02161 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 04:36:13 +0000 Subject: [PATCH 05/19] vllm/v1/core/sched/scheduler.py: remove debug code Signed-off-by: Thomas Parnell --- vllm/v1/core/sched/scheduler.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4bb2bbd6e2c..2d2274ab6a4 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -8,7 +8,6 @@ from collections.abc import Iterable from typing import Any, Optional, Union -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -212,13 +211,6 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, self.max_model_len - request.num_computed_tokens) - if envs.VLLM_V1_USE_DEMO_LOGGING and num_new_tokens > 1: - logger.info("request_id: %s", request.request_id) - logger.info("num_tokens: %d", request.num_tokens) - logger.info("num_computed_tokens: %d", - request.num_computed_tokens) - logger.info("num_new_tokens: %d", num_new_tokens) - # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget @@ -424,12 +416,6 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - if envs.VLLM_V1_USE_DEMO_LOGGING: - logger.info("request_id: %s", request.request_id) - logger.info("num_tokens: %d", request.num_tokens) - logger.info("num_computed_tokens: %d", num_computed_tokens) - logger.info("num_new_tokens: %d", num_new_tokens) - new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, From 412eacd5c9dfe805dfa3ef96e08c06363c20b385 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 04:37:01 +0000 Subject: [PATCH 06/19] vllm/envs.py Signed-off-by: Thomas Parnell --- vllm/envs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 92d985c79fc..921052821ee 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -120,7 +120,6 @@ VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 - VLLM_V1_USE_DEMO_LOGGING: bool = True VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 @@ -836,10 +835,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), - # Useful for demo - "VLLM_V1_USE_DEMO_LOGGING": - lambda: os.environ.get("VLLM_V1_USE_DEMO_LOGGING", "0") == "1", - # If set, allow insecure serialization using pickle. # This is useful for environments where it is deemed safe to use the # insecure method and it is needed for some reason. 
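Note on the masking logic carried through the patches below: the "save-prefix" mask introduced in PATCH 01 and refactored in PATCH 07/09 implements a per-request rule — for each request, the first length - k_offset tokens keep the base-model output, and only the trailing k_offset tokens (from the aLoRA invocation sequence onward) receive the LoRA delta. A minimal standalone sketch of that rule follows; it is illustrative only, is not part of this patch series, uses toy values for query_start_loc and k_offsets, and only mirrors the names in ALoRAMetadata rather than calling any vLLM API.

# Illustrative sketch (not part of the patch series): build the 1D
# "save-prefix" mask from cumulative request boundaries and per-request
# k_offsets, using the same scatter_add_/cumsum trick as the patch.
import torch

def build_alora_mask(query_start_loc: torch.Tensor,
                     k_offsets: torch.Tensor) -> torch.Tensor:
    """Bool mask over all tokens in the batch; True = keep base-model output."""
    starts = query_start_loc[:-1]                  # first token index per request
    lengths = query_start_loc[1:] - starts         # tokens per request
    kept_lens = torch.clamp(lengths - k_offsets, min=0)  # base-model prefix length

    total_tokens = int(query_start_loc[-1])
    # +1/-1 "fence posts" at the start and end of each kept prefix, then a
    # running sum marks every token inside some kept prefix.
    delta = torch.zeros(total_tokens + 1)
    delta.scatter_add_(0, starts, kept_lens.sign().float())
    delta.scatter_add_(0, starts + kept_lens, -kept_lens.sign().float())
    return torch.cumsum(delta[:-1], dim=0) > 0

# Toy batch: two requests of lengths 5 and 4, adapter active on the last
# 2 tokens of each request.
mask = build_alora_mask(torch.tensor([0, 5, 9]), torch.tensor([2, 2]))
print(mask)  # [True, True, True, False, False, True, True, False, False]

The scatter_add_/cumsum construction is what lets the mask be built in O(total_tokens) on the GPU without a Python loop over requests, which is why the layer can apply it directly to the flattened (seq_len, hidden_dim) output tensor.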
From 32098e40787e846304ee0be75fcf9aca724c0b07 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 18:45:11 +0000 Subject: [PATCH 07/19] Inject aLoRA behaviour via mixin Signed-off-by: Thomas Parnell --- vllm/config.py | 2 + vllm/lora/layers.py | 98 ++++++++++++++++++++++++++++----------------- vllm/lora/utils.py | 11 ++++- 3 files changed, 73 insertions(+), 38 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d986ab6b0ed..6c1ad60f2ab 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2873,6 +2873,8 @@ class LoRAConfig: allowed.""" bias_enabled: bool = False """Enable bias for LoRA adapters.""" + activated_lora_enabled: bool = True + """Enable Activated LoRA.""" def compute_hash(self) -> str: """ diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index db8a881b089..775c12a2822 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -419,44 +419,14 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - # Extract aLoRA batch metadata from forward context - alora_metadata = get_forward_context().alora_metadata - k_offsets = alora_metadata.k_offsets - query_start_locs = alora_metadata.query_start_locs - - # Build the 1D “save‐prefix” mask: - T = output.size(0) # total tokens - starts = query_start_locs[:-1] # starts and end index of each request - ends = query_start_locs[1:] - lengths = ends - starts # request lengths - kept_lens = lengths - k_offsets - kept_lens = torch.clamp( - kept_lens, - min=0) # portion of request to keep as base model weights - - device = output.device - # Create the alora mask - delta = torch.zeros(T + 1, device=device, dtype=output.dtype) - ends_for_scatter = starts + kept_lens - pos_vals = kept_lens.sign().to(output.dtype) - neg_vals = -pos_vals - delta.scatter_add_(0, starts, pos_vals) - delta.scatter_add_(0, ends_for_scatter, neg_vals) - cums = torch.cumsum(delta[:-1], dim=0) - mask1d = cums > 0 # shape [T], bool - mask2d = mask1d.unsqueeze(1).to(output.dtype) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output - # Clone base layer output before running LoRA - orig_out = output.clone() - - # Apply LoRA in‐place on `output`: - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) - # Apply alora mask - final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d) - return final_output + return output @property def weight(self) -> torch.Tensor: @@ -1314,3 +1284,57 @@ def can_replace_layer( def extra_repr(self) -> str: return self.base_layer.extra_repr() + + +class ActivatedLoRAMixin: + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. 
+ if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + # Extract aLoRA batch metadata from forward context + alora_metadata = get_forward_context().alora_metadata + k_offsets = alora_metadata.k_offsets + query_start_locs = alora_metadata.query_start_locs + + # Build the 1D “save‐prefix” mask: + T = output.size(0) # total tokens + starts = query_start_locs[:-1] # starts and end index of each request + ends = query_start_locs[1:] + lengths = ends - starts # request lengths + kept_lens = lengths - k_offsets + kept_lens = torch.clamp( + kept_lens, + min=0) # portion of request to keep as base model weights + + device = output.device + # Create the alora mask + delta = torch.zeros(T + 1, device=device, dtype=output.dtype) + ends_for_scatter = starts + kept_lens + pos_vals = kept_lens.sign().to(output.dtype) + neg_vals = -pos_vals + delta.scatter_add_(0, starts, pos_vals) + delta.scatter_add_(0, ends_for_scatter, neg_vals) + cums = torch.cumsum(delta[:-1], dim=0) + mask1d = cums > 0 # shape [T], bool + mask2d = mask1d.unsqueeze(1).to(output.dtype) + + # Clone base layer output before running LoRA + orig_out = output.clone() + + # Apply LoRA in‐place on `output`: + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) + # Apply alora mask + final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d) + return final_output diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index ee196e3f689..6bafc6a00a5 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -21,7 +21,8 @@ # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, +from vllm.lora.layers import (ActivatedLoRAMixin, BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -67,6 +68,14 @@ def from_layer(layer: nn.Module, lora_config=lora_config, packed_modules_list=packed_modules_list, model_config=model_config): + + # inject a-LoRA behaviour + if (lora_config.activated_lora_enabled + and lora_cls is MergedQKVParallelLinearWithLoRA): + lora_cls = type( + lora_cls.__name__.replace("LoRA", "ActivatedLoRA"), + (ActivatedLoRAMixin, lora_cls), {}) + instance_layer = lora_cls(layer) instance_layer.create_lora_weights(max_loras, lora_config, model_config) From fb6d28ec829637aaf816e774353bf8513d9e3891 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 19:16:13 +0000 Subject: [PATCH 08/19] Simpler implementation without mixin Signed-off-by: Thomas Parnell --- vllm/lora/layers.py | 19 +++++++++++++++++-- vllm/lora/utils.py | 13 +++---------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 775c12a2822..5d653a5ac4e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -875,7 +875,8 @@ def can_replace_layer( model_config: Optional[PretrainedConfig], ) -> bool: return (type(source_layer) is QKVParallelLinear - and len(packed_modules_list) == 3) + and len(packed_modules_list) == 3 + and not lora_config.activated_lora_enabled) #TODO: Implement this @@ -1286,7 +1287,8 @@ def extra_repr(self) -> str: return self.base_layer.extra_repr() -class ActivatedLoRAMixin: +class MergedQKVParallelLinearWithActivatedLoRA(MergedQKVParallelLinearWithLoRA + ): def apply(self, x: torch.Tensor, @@ -1338,3 +1340,16 @@ def 
apply(self, # Apply alora mask final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d) return final_output + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3 + and lora_config.activated_lora_enabled) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 6bafc6a00a5..f5ad741fdcc 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -21,11 +21,11 @@ # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable -from vllm.lora.layers import (ActivatedLoRAMixin, BaseLayerWithLoRA, - ColumnParallelLinearWithLoRA, +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithActivatedLoRA, MergedQKVParallelLinearWithLoRA, QKVParallelLinearWithLoRA, ReplicatedLinearWithLoRA, @@ -45,6 +45,7 @@ MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithActivatedLoRA, RowParallelLinearWithLoRA, ReplicatedLinearWithLoRA, LogitsProcessorWithLoRA, @@ -68,14 +69,6 @@ def from_layer(layer: nn.Module, lora_config=lora_config, packed_modules_list=packed_modules_list, model_config=model_config): - - # inject a-LoRA behaviour - if (lora_config.activated_lora_enabled - and lora_cls is MergedQKVParallelLinearWithLoRA): - lora_cls = type( - lora_cls.__name__.replace("LoRA", "ActivatedLoRA"), - (ActivatedLoRAMixin, lora_cls), {}) - instance_layer = lora_cls(layer) instance_layer.create_lora_weights(max_loras, lora_config, model_config) From 5f62d8beeb439b7f4ec37c126653ca829c2482fb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 22:15:01 +0000 Subject: [PATCH 09/19] Scan for invocation tokens in one place Signed-off-by: Thomas Parnell --- vllm/forward_context.py | 2 +- vllm/lora/layers.py | 6 +- vllm/lora/request.py | 2 +- vllm/v1/core/kv_cache_utils.py | 22 ++--- vllm/v1/engine/processor.py | 30 ++++++- vllm/v1/worker/gpu_model_runner.py | 129 ++++++++--------------------- 6 files changed, 70 insertions(+), 121 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index cf8447a5a25..54248f89b38 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -29,7 +29,7 @@ @dataclass class ALoRAMetadata: k_offsets: torch.Tensor - query_start_locs: torch.Tensor + query_start_loc: torch.Tensor @dataclass diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5d653a5ac4e..da25de8317e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1305,12 +1305,12 @@ def apply(self, # Extract aLoRA batch metadata from forward context alora_metadata = get_forward_context().alora_metadata k_offsets = alora_metadata.k_offsets - query_start_locs = alora_metadata.query_start_locs + query_start_loc = alora_metadata.query_start_loc # Build the 1D “save‐prefix” mask: T = output.size(0) # total tokens - starts = query_start_locs[:-1] # starts and end index of each request - ends = query_start_locs[1:] + starts = query_start_loc[:-1] # starts and end index of each request + ends = query_start_loc[1:] lengths = ends - starts # request lengths kept_lens = lengths - k_offsets kept_lens = torch.clamp( diff --git 
a/vllm/lora/request.py b/vllm/lora/request.py index a566e50d605..c5851af8c21 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -33,7 +33,7 @@ class LoRARequest( long_lora_max_len: Optional[int] = None base_model_name: Optional[str] = msgspec.field(default=None) tensorizer_config_dict: Optional[dict] = None - invocation_tokens: Optional[list[int]] = None + invocation_start: Optional[int] = None k_offset: Optional[int] = None def __post_init__(self): diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 7eff333a124..b5a0af9cb76 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -457,20 +457,6 @@ def hash_request_tokens(hash_function: Any, block_size: int, token_ids = request.all_token_ids req_need_extra_keys = need_extra_keys(request) - if (request.lora_request is not None - and request.lora_request.invocation_tokens is not None): - use_alora = True - invocation_tokens = request.lora_request.invocation_tokens - # scan backward for the last match (faster than full forward scan+max) - invocation_start = -1 - n = len(invocation_tokens) - for idx in range(len(token_ids) - n, -1, -1): - if token_ids[idx:idx + n] == invocation_tokens: - # weights activated 1 token after start - invocation_start = idx + 1 - break - else: - use_alora = False req_extra_keys = None curr_mm_idx = 0 @@ -487,8 +473,12 @@ def hash_request_tokens(hash_function: Any, block_size: int, # MM and LoRA requests need extra keys for block-hash computation. req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) - if use_alora and end <= invocation_start: - req_extra_keys = None # cache is equivalent to base model cache + # Respect a-LoRA behaviour + if (request.lora_request is not None + and request.lora_request.invocation_start is not None + and end <= request.lora_request.invocation_start): + # cache is equivalent to base model cache + req_extra_keys = None block_hash = hash_block_tokens(hash_function, parent_block_hash_value, block_token_ids, req_extra_keys) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index d6b0c28f16f..6861525b856 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -330,20 +330,42 @@ def process_inputs( # Tokenize aLoRA invocation sequence if applicable. if lora_request is not None: + # tpa: can we get this from PeftHelper somehow? # Load in adapter config file lora_path = get_adapter_absolute_path(lora_request.lora_path) lora_config_path = os.path.join(lora_path, "adapter_config.json") with open(lora_config_path) as f: config = json.load(f) - if "invocation_string" in config: # check if aLoRA + if "invocation_string" in config: + invocation_tokens = self.input_preprocessor._tokenize_prompt( config["invocation_string"], lora_request=lora_request, tokenization_kwargs=tokenization_kwargs) - # Make it an aLoRA request - # (in future, this will happen upstream) - lora_request.invocation_tokens = invocation_tokens + + invocation_start = -1 + n = len(invocation_tokens) + token_ids = decoder_inputs["prompt_token_ids"] + + if n > 0 and len(token_ids) >= n: + # scan backward for the last match + # (faster than full forward scan+max) + for idx in range(len(token_ids) - n, -1, -1): + if token_ids[idx:idx + n] == invocation_tokens: + # weights activated 1 token after start + invocation_start = idx + 1 + break + + if invocation_start == -1: + raise ValueError( + "Invocation sequence not found in prompt " + f"for request '{request_id}'. 
aLoRA models require the " + "invocation tokens to be present in the input.") + + lora_request.invocation_start = invocation_start + lora_request.k_offset = len(token_ids) - invocation_start + return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a9bd1edf3cc..83fc2f7d386 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -557,7 +557,7 @@ def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata]]: + Optional[SpecDecodeMetadata], Optional[ALoRAMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -743,84 +743,28 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata) - - def _extract_offsets( - self, - scheduler_output: "SchedulerOutput", - ) -> ALoRAMetadata: - """ - Extract k_offsets for each new scheduled req that is called with aLoRA. - Prepare aLoRA metadata for model execution. - """ - - for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - if (new_req_data.lora_request is not None and - new_req_data.lora_request.invocation_tokens is not None): - tokens = new_req_data.lora_request.invocation_tokens - prompt_ids = new_req_data.prompt_token_ids - n = len(tokens) - k_offset = -1 - # only bother if there actually are invocation tokens - if n > 0 and len(prompt_ids) >= n: - # scan backward for the last match - # (faster than full forward scan+max) - for idx in range(len(prompt_ids) - n, -1, -1): - if prompt_ids[idx:idx + n] == tokens: - # offset = number of tokens from the start - # of that match to the end of the prompt - k_offset = len(prompt_ids) - idx - 1 - break - if k_offset == -1: - raise ValueError( - "Invocation sequence not found in prompt " - f"for request '{req_id}'. aLoRA models require the " - "invocation tokens to be present in the input.") - + # Compute a-LoRA metadata + if self.lora_config.activated_lora_enabled: + k_offsets = [1] * (num_reqs) + for req_id in self.input_batch.req_ids: + req_index = self.input_batch.req_id_to_index[req_id] cached_lora_request = self.requests[req_id].lora_request - assert cached_lora_request is not None - cached_lora_request.k_offset = k_offset - - # Fill in k_offsets based on the `scheduled_new_reqs` and - # `scheduled_cached_reqs` within the SchedulerOutput. 
- num_seqs = len(self.query_start_loc_np.tolist()) - 1 - k_offsets = [1] * (num_seqs) - - for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - req_index = self.input_batch.req_id_to_index[req_id] - cached_lora_request = self.requests[req_id].lora_request - if (cached_lora_request is not None - and cached_lora_request.k_offset is not None): - k_offsets[req_index] = cached_lora_request.k_offset - else: - k_offsets[req_index] = len( - self.requests[req_id].prompt_token_ids) - - for cached_req_data in scheduler_output.scheduled_cached_reqs: - req_id = cached_req_data.req_id - req_index = self.input_batch.req_id_to_index[req_id] - cached_lora_request = self.requests[req_id].lora_request - if (cached_lora_request is not None - and cached_lora_request.k_offset is not None): - k_offsets[req_index] = cached_lora_request.k_offset - else: - k_offsets[req_index] = len( - self.requests[req_id].prompt_token_ids) - - query_locs = torch.tensor(self.query_start_loc_np.tolist(), - device=self.device) - - if len(query_locs) > self.input_batch.num_reqs + 1: - query_locs[self.input_batch.num_reqs + 1:] = 0 + if (cached_lora_request is not None + and cached_lora_request.k_offset is not None): + k_offsets[req_index] = cached_lora_request.k_offset + else: + k_offsets[req_index] = len( + self.requests[req_id].prompt_token_ids) - alora_metadata = ALoRAMetadata(k_offsets=torch.tensor( - k_offsets, device=self.device), - query_start_locs=query_locs) + alora_metadata = ALoRAMetadata( + k_offsets=torch.tensor(k_offsets, device=self.device), + query_start_loc=query_start_loc.to(torch.int64), + ) + else: + alora_metadata = None - return alora_metadata + return (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, alora_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1284,11 +1228,8 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata) = (self._prepare_inputs(scheduler_output)) - - # tpa - let's do this in prepare input> - # Extract the aLoRA offsets if applicable. 
- alora_metadata = self._extract_offsets(scheduler_output) + spec_decode_metadata, + alora_metadata) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph @@ -1954,27 +1895,23 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - # Prepare dummy ALoRAMetadata - dummy_k_offsets = torch.tensor([1] * max_num_reqs, - device=self.device) - dummy_cu_num_tokens = np.cumsum(num_scheduled_tokens) - dummy_query_start_loc = [0] * (max_num_reqs + 1) - dummy_query_start_loc[0] = 0 - dummy_query_start_loc[1:num_reqs + 1] = dummy_cu_num_tokens - dummy_query_start_loc = torch.tensor(dummy_query_start_loc, - device=self.device) - dummy_alora_metadata = ALoRAMetadata( - k_offsets=dummy_k_offsets, - query_start_locs=dummy_query_start_loc, - ) - #num_reqs=num_reqs,) + if self.lora_config.activated_lora_enabled: + k_offsets = torch.tensor([1] * num_reqs, device=self.device) + query_start_loc = self.query_start_loc[:num_reqs + 1].to( + torch.int64) + alora_metadata = ALoRAMetadata( + k_offsets=k_offsets, + query_start_loc=query_start_loc, + ) + else: + alora_metadata = None with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, - alora_metadata=dummy_alora_metadata): + alora_metadata=alora_metadata): outputs = model( input_ids=input_ids, positions=positions, From f9396b0d01d0f9275fdd90e17075de05d59da9f1 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 18 Jun 2025 22:23:35 +0000 Subject: [PATCH 10/19] Just use single field in request Signed-off-by: Thomas Parnell --- vllm/lora/request.py | 1 - vllm/v1/core/kv_cache_utils.py | 4 ++-- vllm/v1/engine/processor.py | 9 ++++----- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index c5851af8c21..d0f39f85219 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -33,7 +33,6 @@ class LoRARequest( long_lora_max_len: Optional[int] = None base_model_name: Optional[str] = msgspec.field(default=None) tensorizer_config_dict: Optional[dict] = None - invocation_start: Optional[int] = None k_offset: Optional[int] = None def __post_init__(self): diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b5a0af9cb76..6dbbe208b78 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -475,8 +475,8 @@ def hash_request_tokens(hash_function: Any, block_size: int, request, start, end, curr_mm_idx) # Respect a-LoRA behaviour if (request.lora_request is not None - and request.lora_request.invocation_start is not None - and end <= request.lora_request.invocation_start): + and request.lora_request.k_offset is not None and end + <= (len(token_ids) - request.lora_request.k_offset)): # cache is equivalent to base model cache req_extra_keys = None diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6861525b856..58abf1e1e6b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -344,7 +344,7 @@ def process_inputs( lora_request=lora_request, tokenization_kwargs=tokenization_kwargs) - invocation_start = -1 + k_offset = -1 n = len(invocation_tokens) token_ids = decoder_inputs["prompt_token_ids"] @@ -354,17 +354,16 @@ def process_inputs( for idx in range(len(token_ids) - n, -1, -1): if token_ids[idx:idx + n] == invocation_tokens: # weights activated 1 token after start - 
invocation_start = idx + 1 + k_offset = len(token_ids) - idx - 1 break - if invocation_start == -1: + if k_offset == -1: raise ValueError( "Invocation sequence not found in prompt " f"for request '{request_id}'. aLoRA models require the " "invocation tokens to be present in the input.") - lora_request.invocation_start = invocation_start - lora_request.k_offset = len(token_ids) - invocation_start + lora_request.k_offset = k_offset return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, From 6f36f6d0f93c670bfe4605c9e89a3b6d99a8883f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 04:08:40 +0000 Subject: [PATCH 11/19] Use peft_helper instead of reading files directly Signed-off-by: Thomas Parnell --- vllm/lora/peft_helper.py | 2 ++ vllm/v1/engine/processor.py | 21 +++++++++------------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index a20d73f0f72..edd7be16fdf 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -37,6 +37,8 @@ class PEFTHelper: use_dora: bool = field(default=False) # long context lora field context_length: int = field(default=0) + # Invocation string for Activated LoRA (aLoRA, see: https://arxiv.org/abs/2504.12397) + invocation_string: Optional[str] = field(default=None) # Extra vllm field, start with 'vllm_' to avoid conflict vllm_lora_scaling_factor: float = field(default=1.0) vllm_max_position_embeddings: Optional[int] = field(default=False) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 58abf1e1e6b..0a6de2c3b08 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json -import os import time from collections.abc import Mapping, Sequence from typing import Any, Literal, Optional, Union @@ -11,8 +9,8 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor +from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest -from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.multimodal.inputs import PlaceholderRange @@ -328,19 +326,18 @@ def process_inputs( sorted_mm_inputs = orig_sorted_mm_inputs # Tokenize aLoRA invocation sequence if applicable. - if lora_request is not None: + if self.lora_config.activated_lora_enabled and lora_request is not None: - # tpa: can we get this from PeftHelper somehow? 
- # Load in adapter config file - lora_path = get_adapter_absolute_path(lora_request.lora_path) - lora_config_path = os.path.join(lora_path, "adapter_config.json") - with open(lora_config_path) as f: - config = json.load(f) + text_config = self.model_config.hf_config.get_text_config() - if "invocation_string" in config: + peft_helper = PEFTHelper.from_local_dir( + lora_request.lora_path, text_config.max_position_embeddings, + lora_request.tensorizer_config_dict) + + if peft_helper.invocation_string is not None: invocation_tokens = self.input_preprocessor._tokenize_prompt( - config["invocation_string"], + peft_helper.invocation_string, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs) From c6ffe8f8aca86faafae720a8e5b2f994d577f3e0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 04:12:10 +0000 Subject: [PATCH 12/19] Remove online example for now. Signed-off-by: Thomas Parnell --- ...ra_testing.py => alora_offline_example.py} | 0 examples/alora/alora_server_testing.py | 67 ------------------- examples/alora/alora_server_testing.sh | 46 ------------- 3 files changed, 113 deletions(-) rename examples/alora/{new_alora_testing.py => alora_offline_example.py} (100%) delete mode 100644 examples/alora/alora_server_testing.py delete mode 100644 examples/alora/alora_server_testing.sh diff --git a/examples/alora/new_alora_testing.py b/examples/alora/alora_offline_example.py similarity index 100% rename from examples/alora/new_alora_testing.py rename to examples/alora/alora_offline_example.py diff --git a/examples/alora/alora_server_testing.py b/examples/alora/alora_server_testing.py deleted file mode 100644 index e9616600a54..00000000000 --- a/examples/alora/alora_server_testing.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# After starting server using "vllm serve --enable_lora --lora_modules..." - -import time - -from openai import OpenAI - -model_id = "ibm-granite/granite-3.2-8b-instruct" - -# Modify OpenAI's API key and API base to use vLLM's API server. 
-openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - -BASE_NAME = "ibm-granite/granite-3.2-8b-instruct" -ALORA_NAME = "new_alora" # "ibm-granite/granite-3.2-8b-alora-uncertainty" -invocation_string = "<|start_of_role|>certainty<|end_of_role|>" - -################################################################### -prompts = [ - "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>", - "What is MIT?", - ( - "<|start_of_role|>user<|end_of_role|>What is the capital of " - "Massachusetts?<|end_of_text|>\n" - ), - "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>", - ( - "<|start_of_role|>user<|end_of_role|>What is the capital of " - "Massachusetts?<|end_of_text|>\n" - ), - "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>", -] - -# Base model call -outputs_base = client.completions.create( - model=BASE_NAME, prompt=prompts, temperature=0, max_tokens=600 -) - -choices = outputs_base.choices -generated_text = [] -for i in range(len(prompts)): - prompt = prompts[i] - - generated_text += [outputs_base.choices[i].text] - print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}") - -prompts_alora = [ - x + y + "<|end_of_text|>\n" + invocation_string - for x, y in zip(prompts, generated_text) -] - -# Base model with aLoRA call -t0 = time.time() -alora_outputs = client.completions.create( - model=ALORA_NAME, prompt=prompts_alora, temperature=0, max_tokens=10 -) -t = time.time() - t0 -print(f"Time: {t}") -for i in range(len(prompts_alora)): - prompt = prompts_alora[i] - generated_text = alora_outputs.choices[i].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/alora/alora_server_testing.sh b/examples/alora/alora_server_testing.sh deleted file mode 100644 index 49eb9c5612f..00000000000 --- a/examples/alora/alora_server_testing.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -# More documentation: https://docs.vllm.ai/en/v0.8.3/serving/openai_compatible_server.html#vllm-serve -export VLLM_USE_V1="1" -# Specify base model (and optionally loras) to load in when starting the server. -vllm serve ibm-granite/granite-3.2-8b-instruct \ - --enable-lora \ - --lora-modules '{"name": "new_alora", "path": "/proj/dmfexp/statllm/users/kgreenewald/.cache/huggingface/models/hub/models--ibm-granite--granite-3.2-8b-alora-uncertainty/snapshots/6109ad88201426003e696d023ec67c19e7f3d444", "base_model_name": "ibm-granite/granite-3.2-8b-instruct"}' \ - --dtype bfloat16 \ - --max-lora-rank 64 \ - --enable-prefix-caching -#--no-enable-prefix-caching -# Check that the lora model is listed along with other models. -#curl localhost:8000/v1/models | jq . - -########################################### - -# A second option is to enable dynamic adapter loading instead of at start-up. -#export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True - -#curl -X POST http://localhost:8000/v1/load_lora_adapter \ -#-H "Content-Type: application/json" \ -#-d '{ -# "lora_name": "new_alora", -# "lora_path": "/path/to/new_alora" -#}' -# Should return "200 OK - Success: LoRA adapter 'new_alora' added successfully" - -# Example of dynamically unloading an adapter. 
-# curl -X POST http://localhost:8000/v1/unload_lora_adapter \ -# -H "Content-Type: application/json" \ -# -d '{ -# "lora_name": "new_alora" -# }' - -########################################### - -# Send a request using the new aLoRA -#curl http://localhost:8000/v1/completions \ -# -H "Content-Type: application/json" \ -# -d '{ -# "model": "new_alora", -# "prompt": ""What is MIT?"", -# "max_tokens": 600, -# "temperature": 0 -# }' | jq From 4a4b568ccbaf67961883ec0900c7d36ead99b3aa Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 09:43:19 +0000 Subject: [PATCH 13/19] Further simplification; works with chunked prefill; correct output with torch.compile Signed-off-by: Thomas Parnell --- vllm/forward_context.py | 3 +-- vllm/lora/layers.py | 25 ++--------------- vllm/lora/request.py | 2 +- vllm/model_executor/layers/linear.py | 14 +++++----- vllm/v1/core/kv_cache_utils.py | 4 +-- vllm/v1/engine/processor.py | 8 +++--- vllm/v1/worker/gpu_model_runner.py | 40 +++++++++++++++------------- 7 files changed, 40 insertions(+), 56 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 54248f89b38..a9870dae6e6 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -28,8 +28,7 @@ @dataclass class ALoRAMetadata: - k_offsets: torch.Tensor - query_start_loc: torch.Tensor + mask1d: torch.Tensor @dataclass diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index da25de8317e..145d3bbcc5b 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1304,29 +1304,8 @@ def apply(self, # Extract aLoRA batch metadata from forward context alora_metadata = get_forward_context().alora_metadata - k_offsets = alora_metadata.k_offsets - query_start_loc = alora_metadata.query_start_loc - - # Build the 1D “save‐prefix” mask: - T = output.size(0) # total tokens - starts = query_start_loc[:-1] # starts and end index of each request - ends = query_start_loc[1:] - lengths = ends - starts # request lengths - kept_lens = lengths - k_offsets - kept_lens = torch.clamp( - kept_lens, - min=0) # portion of request to keep as base model weights - - device = output.device - # Create the alora mask - delta = torch.zeros(T + 1, device=device, dtype=output.dtype) - ends_for_scatter = starts + kept_lens - pos_vals = kept_lens.sign().to(output.dtype) - neg_vals = -pos_vals - delta.scatter_add_(0, starts, pos_vals) - delta.scatter_add_(0, ends_for_scatter, neg_vals) - cums = torch.cumsum(delta[:-1], dim=0) - mask1d = cums > 0 # shape [T], bool + + mask1d = alora_metadata.mask1d mask2d = mask1d.unsqueeze(1).to(output.dtype) # Clone base layer output before running LoRA diff --git a/vllm/lora/request.py b/vllm/lora/request.py index d0f39f85219..64762f15ec2 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -33,7 +33,7 @@ class LoRARequest( long_lora_max_len: Optional[int] = None base_model_name: Optional[str] = msgspec.field(default=None) tensorizer_config_dict: Optional[dict] = None - k_offset: Optional[int] = None + invocation_start: Optional[int] = None def __post_init__(self): if self.lora_local_path: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index ddf55bc49ad..9d506b87539 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -230,12 +230,14 @@ def __init__( ): super().__init__() - # lets torch.compile know that forward_context needs to be - # considered as an input to the layer (copied from attention) - compilation_config = get_current_vllm_config().compilation_config - if 
prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self + vllm_config = get_current_vllm_config() + if vllm_config.lora_config.activated_lora_enabled: + # lets torch.compile know that forward_context needs to be + # considered as an input to the layer (copied from attention) + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self # Keep input parameters self.input_size = input_size diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6dbbe208b78..b5a0af9cb76 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -475,8 +475,8 @@ def hash_request_tokens(hash_function: Any, block_size: int, request, start, end, curr_mm_idx) # Respect a-LoRA behaviour if (request.lora_request is not None - and request.lora_request.k_offset is not None and end - <= (len(token_ids) - request.lora_request.k_offset)): + and request.lora_request.invocation_start is not None + and end <= request.lora_request.invocation_start): # cache is equivalent to base model cache req_extra_keys = None diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 0a6de2c3b08..40ac57951c2 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -341,7 +341,7 @@ def process_inputs( lora_request=lora_request, tokenization_kwargs=tokenization_kwargs) - k_offset = -1 + invocation_start = -1 n = len(invocation_tokens) token_ids = decoder_inputs["prompt_token_ids"] @@ -351,16 +351,16 @@ def process_inputs( for idx in range(len(token_ids) - n, -1, -1): if token_ids[idx:idx + n] == invocation_tokens: # weights activated 1 token after start - k_offset = len(token_ids) - idx - 1 + invocation_start = idx + 1 break - if k_offset == -1: + if invocation_start == -1: raise ValueError( "Invocation sequence not found in prompt " f"for request '{request_id}'. aLoRA models require the " "invocation tokens to be present in the input.") - lora_request.k_offset = k_offset + lora_request.invocation_start = invocation_start return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 83fc2f7d386..1bed6d3c680 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -229,6 +229,11 @@ def __init__( dtype=torch.int64, device=self.device) + if self.lora_config.activated_lora_enabled: + self.mask1d = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + # None in the first PP rank. The rest are set after load_model. 
self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -745,21 +750,24 @@ def _prepare_inputs( # Compute a-LoRA metadata if self.lora_config.activated_lora_enabled: - k_offsets = [1] * (num_reqs) + invocation_start = np.empty(shape=(num_reqs, ), dtype=int) for req_id in self.input_batch.req_ids: req_index = self.input_batch.req_id_to_index[req_id] cached_lora_request = self.requests[req_id].lora_request if (cached_lora_request is not None - and cached_lora_request.k_offset is not None): - k_offsets[req_index] = cached_lora_request.k_offset + and cached_lora_request.invocation_start is not None): + invocation_start[ + req_index] = cached_lora_request.invocation_start else: - k_offsets[req_index] = len( + invocation_start[req_index] = len( self.requests[req_id].prompt_token_ids) - - alora_metadata = ALoRAMetadata( - k_offsets=torch.tensor(k_offsets, device=self.device), - query_start_loc=query_start_loc.to(torch.int64), - ) + mask1d_cpu = torch.tensor(positions_np + < invocation_start[req_indices], + dtype=torch.bool, + device="cpu") + mask1d = self.mask1d[:total_num_scheduled_tokens] + mask1d.copy_(mask1d_cpu, non_blocking=True) + alora_metadata = ALoRAMetadata(mask1d=mask1d) else: alora_metadata = None @@ -1895,16 +1903,12 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) + alora_metadata = None if self.lora_config.activated_lora_enabled: - k_offsets = torch.tensor([1] * num_reqs, device=self.device) - query_start_loc = self.query_start_loc[:num_reqs + 1].to( - torch.int64) - alora_metadata = ALoRAMetadata( - k_offsets=k_offsets, - query_start_loc=query_start_loc, - ) - else: - alora_metadata = None + mask1d = self.mask1d[:num_tokens] + alora_metadata = ALoRAMetadata(mask1d=mask1d) + # needed to avoid guard failures + torch._dynamo.mark_dynamic(alora_metadata.mask1d, 0) with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, From 4cbef84038c4406db6b4d40cd70a477fb358484f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 14:10:26 +0000 Subject: [PATCH 14/19] Add enable_activated_lora engine arg Signed-off-by: Thomas Parnell --- examples/alora/alora_offline_example.py | 6 ++---- vllm/config.py | 2 +- vllm/engine/arg_utils.py | 4 ++++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/alora/alora_offline_example.py b/examples/alora/alora_offline_example.py index 5d3908ff8f2..c3b049ec587 100644 --- a/examples/alora/alora_offline_example.py +++ b/examples/alora/alora_offline_example.py @@ -10,11 +10,11 @@ from vllm.lora.request import LoRARequest BASE_NAME = "ibm-granite/granite-3.2-8b-instruct" + ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty" invocation_string = "<|start_of_role|>certainty<|end_of_role|>" os.environ["VLLM_USE_V1"] = "1" -os.environ["VLLM_V1_USE_DEMO_LOGGING"] = "1" # download your LoRA adapter to ~/.cache/huggingface/… alora_path = snapshot_download(repo_id=ALORA_NAME) @@ -26,11 +26,9 @@ llm = LLM( model=BASE_NAME, enable_lora=True, - enforce_eager=True, + enable_activated_lora=True, dtype=torch.bfloat16, - enable_prefix_caching=True, # enable APC max_lora_rank=64, - enable_chunked_prefill=False, ) prompts = [ diff --git a/vllm/config.py b/vllm/config.py index 6c1ad60f2ab..fdbb6364301 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2873,7 +2873,7 @@ class LoRAConfig: allowed.""" bias_enabled: bool = False """Enable bias for LoRA adapters.""" - activated_lora_enabled: bool = True + activated_lora_enabled: bool = False """Enable 
Activated LoRA.""" def compute_hash(self) -> str: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f599d7a3bb5..e1751ccb8ef 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -355,6 +355,7 @@ class EngineArgs: # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled + enable_activated_lora: bool = LoRAConfig.activated_lora_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras @@ -733,6 +734,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="If True, enable handling of LoRA adapters.") lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) + lora_group.add_argument("--enable-activated-lora", + **lora_kwargs["activated_lora"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) @@ -1190,6 +1193,7 @@ def create_engine_config( lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, + activated_lora_enabled=self.enable_activated_lora, max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, fully_sharded_loras=self.fully_sharded_loras, From 49a5bdc9e617e63f52e62d563027c80f1961095c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 14:19:48 +0000 Subject: [PATCH 15/19] Disable tqdm in example Signed-off-by: Thomas Parnell --- examples/alora/alora_offline_example.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/alora/alora_offline_example.py b/examples/alora/alora_offline_example.py index c3b049ec587..4133ac341d2 100644 --- a/examples/alora/alora_offline_example.py +++ b/examples/alora/alora_offline_example.py @@ -43,6 +43,7 @@ outputsBase = llm.generate( prompts, sampling_params, + use_tqdm=False, ) generated_text = [] for output in outputsBase: @@ -62,6 +63,7 @@ prompts_alora, sampling_params, lora_request=LoRARequest("UQ_adapter", 1, alora_path), + use_tqdm=False, ) t = time.time() - t0 print(f"Time: {t}") From 5c2e1815900b1d3075ed95535c0be66510575ad2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 14:46:09 +0000 Subject: [PATCH 16/19] Trigger Build Signed-off-by: Thomas Parnell From ceae7c7f45040a602fa919471d9d9b752995a652 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 17:04:29 +0000 Subject: [PATCH 17/19] vllm/model_executor/layers/linear.py: check lora_config exists before checking activated lora flag Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9d506b87539..0648e8e6a19 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -231,7 +231,8 @@ def __init__( super().__init__() vllm_config = get_current_vllm_config() - if vllm_config.lora_config.activated_lora_enabled: + if (vllm_config.lora_config + and vllm_config.lora_config.activated_lora_enabled): # lets torch.compile know that forward_context needs to be # considered as an input to the layer (copied from attention) compilation_config = vllm_config.compilation_config From 99b8b60ad7ef004b7343ebb9f6eea77667377b2f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 17:23:45 +0000 Subject: [PATCH 18/19] arg_utils.py: fix typo Signed-off-by: Thomas Parnell --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5de42cac7fe..82f33acc4fe 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -735,7 +735,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) lora_group.add_argument("--enable-activated-lora", - **lora_kwargs["activated_lora"]) + **lora_kwargs["activated_lora_enabled"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) From 477ab6eb14d01f8a51562a538bd6c8c76adff6eb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 19 Jun 2025 18:27:57 +0000 Subject: [PATCH 19/19] Additional checking of lora_config Signed-off-by: Thomas Parnell --- vllm/v1/engine/processor.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index e531363829f..c783ced22bd 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -330,7 +330,8 @@ def process_inputs( sorted_mm_inputs = orig_sorted_mm_inputs # Tokenize aLoRA invocation sequence if applicable. - if self.lora_config.activated_lora_enabled and lora_request is not None: + if (self.lora_config and self.lora_config.activated_lora_enabled + and lora_request is not None): text_config = self.model_config.hf_config.get_text_config() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e5ab1018bbc..917931eacbe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -238,7 +238,7 @@ def __init__( dtype=torch.int64, device=self.device) - if self.lora_config.activated_lora_enabled: + if self.lora_config and self.lora_config.activated_lora_enabled: self.mask1d = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) @@ -762,7 +762,7 @@ def _prepare_inputs( self.set_active_loras(self.input_batch, num_scheduled_tokens) # Compute a-LoRA metadata - if self.lora_config.activated_lora_enabled: + if self.lora_config and self.lora_config.activated_lora_enabled: invocation_start = np.empty(shape=(num_reqs, ), dtype=int) for req_id in self.input_batch.req_ids: req_index = self.input_batch.req_id_to_index[req_id] @@ -1967,7 +1967,7 @@ def _dummy_run( num_tokens, None, False) alora_metadata = None - if self.lora_config.activated_lora_enabled: + if self.lora_config and self.lora_config.activated_lora_enabled: mask1d = self.mask1d[:num_tokens] alora_metadata = ALoRAMetadata(mask1d=mask1d) # needed to avoid guard failures
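
The mask refactor in patches 13 and 19 works in two stages: the processor scans the tokenized prompt for the adapter's invocation sequence and records `invocation_start` on the `LoRARequest`, and the model runner expands that per-request offset into a per-token boolean mask over the flattened batch (`mask1d`), which is what `ALoRAMetadata` now carries. The sketch below restates that logic in plain PyTorch with hypothetical helper names; it is illustrative only and does not reuse the actual vLLM classes.

import torch


def find_invocation_start(token_ids: list[int],
                          invocation_tokens: list[int]) -> int:
    # Index of the first token whose weights are LoRA-activated, i.e. one
    # past the start of the last occurrence of the invocation sequence.
    # Returns -1 if the sequence is absent (the processor raises an error
    # in that case).
    n = len(invocation_tokens)
    for idx in range(len(token_ids) - n, -1, -1):
        if token_ids[idx:idx + n] == invocation_tokens:
            return idx + 1
    return -1


def build_mask1d(positions: torch.Tensor,
                 req_indices: torch.Tensor,
                 invocation_start: torch.Tensor) -> torch.Tensor:
    # positions:        [T] position of each scheduled token within its request
    # req_indices:      [T] request index owning each scheduled token
    # invocation_start: [num_reqs] offset recorded by the processor
    # True for tokens that sit before their request's invocation point and
    # should therefore keep the base-model weights.
    return positions < invocation_start[req_indices]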
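
On the layer side, the removed scatter/cumsum construction in `vllm/lora/layers.py` is replaced by reading this precomputed `mask1d` from the forward context and broadcasting it to the hidden dimension (`mask2d`), so prefix tokens keep the cloned base-layer output while tokens from the invocation point onward receive the LoRA-adapted output. The final blend is not visible in these hunks; a plausible sketch, assuming a simple row-wise select, is:

import torch


def blend_alora_output(base_out: torch.Tensor,
                       lora_out: torch.Tensor,
                       mask1d: torch.Tensor) -> torch.Tensor:
    # base_out, lora_out: [T, H] projections without and with the LoRA delta
    # mask1d:             [T] True where the base-model path should be kept
    mask2d = mask1d.unsqueeze(1).to(base_out.dtype)  # [T, 1], broadcast over H
    return mask2d * base_out + (1.0 - mask2d) * lora_out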