[Model] Activated LoRA #19710

Open · wants to merge 21 commits into main
74 changes: 74 additions & 0 deletions examples/alora/alora_offline_example.py
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time

import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"

ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"

os.environ["VLLM_USE_V1"] = "1"

# download your LoRA adapter to ~/.cache/huggingface/…
alora_path = snapshot_download(repo_id=ALORA_NAME)

print(alora_path)
#######################################


llm = LLM(
    model=BASE_NAME,
    enable_lora=True,
    enable_activated_lora=True,
    dtype=torch.bfloat16,
    max_lora_rank=64,
)

prompts = [
    (
        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    ),
]

sampling_params = SamplingParams(temperature=0, max_tokens=600)

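# Generate a response with the base model (no LoRA adapter applied).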
outputsBase = llm.generate(
    prompts,
    sampling_params,
    use_tqdm=False,
)
generated_text = []
for output in outputsBase:
    prompt = output.prompt
    generated_text += [output.outputs[0].text]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")

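# Append each base response plus the aLoRA invocation string to its prompt.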
prompts_alora = [
    x + y + "<|end_of_text|>\n" + invocation_string
    for x, y in zip(prompts, generated_text)
]

sampling_params = SamplingParams(temperature=0, max_tokens=10)

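# Generate again with the aLoRA adapter supplied via LoRARequest.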
t0 = time.time()
outputs = llm.generate(
    prompts_alora,
    sampling_params,
    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
    use_tqdm=False,
)
t = time.time() - t0
print(f"Time: {t}")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -2876,6 +2876,8 @@ class LoRAConfig:
    allowed."""
    bias_enabled: bool = False
    """Enable bias for LoRA adapters."""
    activated_lora_enabled: bool = False
    """Enable Activated LoRA."""

    def compute_hash(self) -> str:
        """
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
@@ -355,6 +355,7 @@ class EngineArgs:
    # LoRA fields
    enable_lora: bool = False
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    enable_activated_lora: bool = LoRAConfig.activated_lora_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
@@ -733,6 +734,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
                                **lora_kwargs["bias_enabled"])
        lora_group.add_argument("--enable-activated-lora",
                                **lora_kwargs["activated_lora_enabled"])
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
                                **lora_kwargs["max_lora_rank"])
@@ -1191,6 +1194,7 @@ def create_engine_config(

        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            activated_lora_enabled=self.enable_activated_lora,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
8 changes: 8 additions & 0 deletions vllm/forward_context.py
@@ -26,6 +26,11 @@
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class ALoRAMetadata:
    mask1d: torch.Tensor


@dataclass
class DPMetadata:
    max_tokens_across_dp_cpu: torch.Tensor
@@ -94,6 +99,7 @@ class ForwardContext:
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    alora_metadata: Optional[ALoRAMetadata] = None
    skip_cuda_graphs: bool = False


@@ -116,6 +122,7 @@ def set_forward_context(
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    skip_cuda_graphs: bool = False,
    alora_metadata: Optional[ALoRAMetadata] = None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
@@ -140,6 +147,7 @@
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        alora_metadata=alora_metadata,
        skip_cuda_graphs=skip_cuda_graphs,
    )

51 changes: 50 additions & 1 deletion vllm/lora/layers.py
@@ -19,6 +19,7 @@
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
from vllm.forward_context import get_forward_context
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
@@ -874,7 +875,8 @@ def can_replace_layer(
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3
                and not lora_config.activated_lora_enabled)


#TODO: Implement this
@@ -1283,3 +1285,50 @@ def can_replace_layer(

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()


class MergedQKVParallelLinearWithActivatedLoRA(MergedQKVParallelLinearWithLoRA
                                               ):

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        # Extract aLoRA batch metadata from forward context
        alora_metadata = get_forward_context().alora_metadata

        mask1d = alora_metadata.mask1d
        mask2d = mask1d.unsqueeze(1).to(output.dtype)

        # Clone base layer output before running LoRA
        orig_out = output.clone()

        # Apply LoRA in-place on `output`:
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        # Apply alora mask
        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
        return final_output

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3
                and lora_config.activated_lora_enabled)
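
For reference, a toy sketch (not part of the PR) of the per-token blend that apply() performs above. It assumes mask1d holds 1 for positions that should keep the base-layer output and 0 for positions where the adapter weights are active; the construction of mask1d itself happens outside this diff.

# Toy illustration of the aLoRA output blend; shapes and values are made up.
import torch

seq_len, hidden = 6, 4
base_out = torch.zeros(seq_len, hidden)   # stand-in for the base layer output
lora_out = torch.ones(seq_len, hidden)    # stand-in for the LoRA-adapted output
invocation_start = 4                      # first position with adapter weights active

# Assumed convention: 1 -> keep base output, 0 -> use LoRA output.
mask1d = (torch.arange(seq_len) < invocation_start).to(base_out.dtype)
mask2d = mask1d.unsqueeze(1)

blended = base_out.mul(mask2d) + lora_out.mul(1.0 - mask2d)
print(blended)  # rows 0-3 equal base_out, rows 4-5 equal lora_out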
2 changes: 2 additions & 0 deletions vllm/lora/peft_helper.py
@@ -37,6 +37,8 @@ class PEFTHelper:
    use_dora: bool = field(default=False)
    # long context lora field
    context_length: int = field(default=0)
    # Invocation string for Activated LoRA (aLoRA, see: https://arxiv.org/abs/2504.12397)
    invocation_string: Optional[str] = field(default=None)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    vllm_lora_scaling_factor: float = field(default=1.0)
    vllm_max_position_embeddings: Optional[int] = field(default=False)
1 change: 1 addition & 0 deletions vllm/lora/request.py
@@ -33,6 +33,7 @@ class LoRARequest(
    long_lora_max_len: Optional[int] = None
    base_model_name: Optional[str] = msgspec.field(default=None)
    tensorizer_config_dict: Optional[dict] = None
    invocation_start: Optional[int] = None

    def __post_init__(self):
        if self.lora_local_path:
2 changes: 2 additions & 0 deletions vllm/lora/utils.py
@@ -25,6 +25,7 @@
                              LinearScalingRotaryEmbeddingWithLoRA,
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithActivatedLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
@@ -44,6 +45,7 @@
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithActivatedLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -9,6 +9,7 @@
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
@@ -229,6 +230,16 @@ def __init__(
    ):
        super().__init__()

        vllm_config = get_current_vllm_config()
        if (vllm_config.lora_config
                and vllm_config.lora_config.activated_lora_enabled):
            # lets torch.compile know that forward_context needs to be
            # considered as an input to the layer (copied from attention)
            compilation_config = vllm_config.compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
6 changes: 6 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
@@ -473,6 +473,12 @@ def hash_request_tokens(hash_function: Any, block_size: int,
        # MM and LoRA requests need extra keys for block-hash computation.
        req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
            request, start, end, curr_mm_idx)
        # Respect aLoRA behaviour
        if (request.lora_request is not None
                and request.lora_request.invocation_start is not None
                and end <= request.lora_request.invocation_start):
            # cache is equivalent to base model cache
            req_extra_keys = None

        block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
                                       block_token_ids, req_extra_keys)
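
A small standalone sketch (hypothetical helper, not vLLM's API) of what the condition above buys: prefix-cache blocks that end before invocation_start hash exactly like base-model blocks, so they can be shared with requests that do not use the adapter.

# Assumed block layout: block i covers tokens [i * block_size, (i + 1) * block_size).
def block_needs_lora_extra_keys(block_idx: int, block_size: int,
                                invocation_start: int) -> bool:
    end = (block_idx + 1) * block_size
    # Mirrors the diff: end <= invocation_start means the block is
    # base-model-equivalent and gets no LoRA-specific extra keys.
    return end > invocation_start

block_size, invocation_start = 16, 40
for i in range(4):
    print(i, block_needs_lora_extra_keys(i, block_size, invocation_start))
# Blocks 0 and 1 (tokens 0-31) reuse base-model hashes; blocks 2 and 3 do not.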
39 changes: 39 additions & 0 deletions vllm/v1/engine/processor.py
@@ -9,6 +9,7 @@
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
@@ -328,6 +329,44 @@ def process_inputs(
        else:
            sorted_mm_inputs = orig_sorted_mm_inputs

        # Tokenize aLoRA invocation sequence if applicable.
        if (self.lora_config and self.lora_config.activated_lora_enabled
                and lora_request is not None):

            text_config = self.model_config.hf_config.get_text_config()

            peft_helper = PEFTHelper.from_local_dir(
                lora_request.lora_path, text_config.max_position_embeddings,
                lora_request.tensorizer_config_dict)

            if peft_helper.invocation_string is not None:

                invocation_tokens = self.input_preprocessor._tokenize_prompt(
                    peft_helper.invocation_string,
                    lora_request=lora_request,
                    tokenization_kwargs=tokenization_kwargs)

                invocation_start = -1
                n = len(invocation_tokens)
                token_ids = decoder_inputs["prompt_token_ids"]

                if n > 0 and len(token_ids) >= n:
                    # scan backward for the last match
                    # (faster than full forward scan+max)
                    for idx in range(len(token_ids) - n, -1, -1):
                        if token_ids[idx:idx + n] == invocation_tokens:
                            # weights activated 1 token after start
                            invocation_start = idx + 1
                            break

                if invocation_start == -1:
                    raise ValueError(
                        "Invocation sequence not found in prompt "
                        f"for request '{request_id}'. aLoRA models require the "
                        "invocation tokens to be present in the input.")

                lora_request.invocation_start = invocation_start

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
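
For clarity, a standalone sketch of the backward scan above with toy token IDs (the real code obtains invocation_tokens by tokenizing peft_helper.invocation_string with the request's tokenizer; the helper name below is hypothetical).

# Hypothetical helper that mirrors the scan in process_inputs().
def find_invocation_start(token_ids: list[int],
                          invocation_tokens: list[int]) -> int:
    n = len(invocation_tokens)
    if n == 0 or len(token_ids) < n:
        return -1
    # Scan backward so the last occurrence of the invocation sequence wins.
    for idx in range(len(token_ids) - n, -1, -1):
        if token_ids[idx:idx + n] == invocation_tokens:
            return idx + 1  # weights activate one token after the match starts
    return -1

prompt_ids = [5, 8, 3, 9, 7, 2, 3, 9]
print(find_invocation_start(prompt_ids, [3, 9]))  # -> 7 (last match at index 6)
print(find_invocation_start(prompt_ids, [4, 4]))  # -> -1 (would raise in vLLM)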