# Adding a New Model

This guide demonstrates how to integrate a novel or customized model into vllm-ascend. For foundational concepts, it is highly recommended to read the [vllm official doc: Adding a New Model](https://docs.vllm.ai/en/stable/contributing/model/) first.

## Step 1: Implementing Models with `torch` and `torch_npu`

This section provides instructions for implementing new models compatible with vllm and vllm-ascend.

**Before starting:**

- Verify whether your model already exists in vllm's [models](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) directory.
- Use existing models' implementations as templates to accelerate your development.

### Method 1: Implementing New Models from Scratch

Follow vllm's [OPT model adaptation](https://docs.vllm.ai/en/stable/contributing/model/basic.html) example for guidance.

**Key implementation requirements:**

1. Place model files in the `vllm_ascend/models/` directory.

2. Standard module structure for decoder-only LLMs (please check out vllm's implementations for other kinds of models):

- `*ModelForCausalLM` (top-level wrapper)
- `*Model` (main architecture)
- `*DecoderLayer` (transformer block)
- `*Attention` and `*MLP` (core computation units)

:::{note}
`*` denotes your model's unique identifier.
:::

3. Critical Implementation Details:

All modules must include a `prefix` argument in `__init__()`.

**Required interfaces:**

| Module Type | Required Methods |
| :------------------- | :---------------------------------------- |
| `*ModelForCausalLM` | `get_input_embeddings`, `compute_logits`, `load_weights` |
| `*Model` | `get_input_embeddings`, `load_weights` |

4. Attention Backend Integration:

Import the attention layer via `from vllm.attention import Attention` so that vllm-ascend's attention backend routing (see `get_attn_backend_cls()` in `vllm_ascend/platform.py`) is applied automatically. A minimal construction sketch follows.
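
The snippet below is only a sketch: the class name, argument values, and omitted projection layers are illustrative and not part of vllm-ascend. The point is how `Attention` is typically constructed (head configuration, scale, and `prefix`):

```python
from typing import Optional

import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import CacheConfig


class SketchAttention(nn.Module):
    """Illustrative attention block; names and shapes are placeholders."""

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        # qkv/output projections are omitted for brevity.
        self.attn = Attention(
            num_heads=num_heads,
            head_size=self.head_dim,
            scale=self.head_dim**-0.5,
            num_kv_heads=num_kv_heads,   # GQA/MQA models
            cache_config=cache_config,
            prefix=f"{prefix}.attn",     # lets the platform pick the NPU backend
        )

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        return self.attn(q, k, v)
```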

5. Tensor Parallelism:

Use vllm's parallel layers (`ColumnParallelLinear`, `VocabParallelEmbedding`, etc.) to implement models that support tensor parallelism. Note that Ascend-specific customizations (RMSNorm, VocabParallelEmbedding, etc.) are implemented in the `vllm_ascend/ops/` directory. A sketch of an MLP built from these layers is shown below.

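This is a minimal sketch assuming a SwiGLU-style MLP; the class and parameter names are illustrative and not part of vllm-ascend:

```python
import torch
from torch import nn
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)


class SketchMLP(nn.Module):
    """Illustrative SwiGLU MLP assembled from vllm's tensor-parallel layers."""

    def __init__(self, hidden_size: int, intermediate_size: int, prefix: str = ""):
        super().__init__()
        # gate_proj and up_proj are fused into one column-parallel projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False,
            prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)   # parallel linear layers return (output, bias)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
```
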
**Reference Implementation Template** (assumed path: `vllm_ascend/models/custom_model.py`):

```python
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.model_executor.sampling_metadata import SamplingMetadata

class CustomAttention(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        # NOTE: Attention also expects num_heads, head_size, scale, etc.
        # (see the sketch above); they are omitted here for brevity.
        self.attn = Attention(prefix=f"{prefix}.attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement attention logic
        ...

class CustomDecoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.self_attn = CustomAttention(vllm_config, prefix=f"{prefix}.self_attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement decoder layer
        ...

class CustomModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.layers = nn.ModuleList([
            CustomDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}")
            for i in range(vllm_config.model_config.hf_config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...

class CustomModelForCausalLM(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.model = CustomModel(vllm_config, prefix=f"{prefix}.model")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def compute_logits(self,
                       hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...
```

### Method 2: Customizing Existing vLLM Models

For most use cases, extending an existing implementation is preferable. The example below inherits from vllm's `DeepseekV2ForCausalLM` to implement a customized DeepSeek model (assumed path: `vllm_ascend/models/deepseek_v2.py`).

```python
from typing import List, Optional, Union
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
from vllm.sequence import IntermediateTensors

class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
    # Define merged weights for quantization/efficiency
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Custom forward logic
        hidden_states = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds
        )
        return hidden_states
```

:::{note}
For a complete implementation reference, see: `vllm_ascend/models/deepseek_v2.py`.
:::

## Step 2: Registering Custom Models using ModelRegistry Plugins in vLLM

vllm provides a plugin mechanism for registering externally implemented models without modifying its codebase.
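
Under the hood, such plugins are discovered through a setuptools entry point in the `vllm.general_plugins` group. The sketch below shows how that wiring typically looks in `setup.py`; the package and entry-point names are illustrative, not a definitive description of vllm-ascend's packaging.

```python
# setup.py (sketch): expose register_model() as a vllm general plugin so that
# vllm calls it automatically on startup. Names here are illustrative.
from setuptools import setup

setup(
    name="vllm-ascend",
    entry_points={
        "vllm.general_plugins": [
            "ascend_enhanced_model = vllm_ascend:register_model",
        ],
    },
)
```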

To integrate your implemented model from the `vllm_ascend/models/` directory:

1. Import your model implementation in `vllm_ascend/models/__init__.py` using relative imports.
2. Register the model wrapper class via the `vllm.ModelRegistry.register_model()` function.

**Reference Registration Template** (an example of registering new models in `vllm_ascend/models/__init__.py`):

```python
from vllm import ModelRegistry

def register_model():
    from .custom_model import CustomModelForCausalLM  # New custom model
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # Customized DeepSeek

    # For NEW architectures: Register with unique name
    ModelRegistry.register_model(
        "CustomModelForCausalLM",  # Must match config.json's 'architectures'
        "vllm_ascend.models.custom_model:CustomModelForCausalLM"
    )

    # For MODIFIED architectures: Use original name
    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",  # Original architecture identifier in vLLM
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM"
    )
```

:::{note}
The first argument of `vllm.ModelRegistry.register_model()` indicates the unique architecture identifier which must match `architectures` in `config.json` of the model.

```json
{
    "architectures": [
        "CustomModelForCausalLM"
    ],
}
```
:::

## Step 3: Verification

### Case 1: Overriding Existing vLLM Model Architecture

If you're registering a customized model architecture based on vllm's existing implementation (overriding vllm's original class), you'll observe warning logs similar to the following from `vllm/model_executor/models/registry.py` when executing vllm offline/online inference (using any model):

```bash
Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
```

### Case 2: Registering New Model Architecture

If you're registering a novel model architecture not present in vllm (creating a completely new class), the current logs won't provide explicit confirmation by default. It's recommended to add the following logging statement at the end of the `register_model` method in `vllm/model_executor/models/registry.py`:

```python
logger.info(f"model_arch: {model_arch} has been registered here!")
```

After adding this line, you will see a confirmation log like the one below when running vllm offline/online inference (using any model):

```bash
model_arch: CustomModelForCausalLM has been registered here!
```

This log output confirms that your novel model architecture has been successfully registered in vllm.
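
For example, a minimal offline inference run such as the sketch below (the model path is a placeholder) is enough to load a model and trigger the registration logs above:

```python
from vllm import LLM, SamplingParams

# Any model works for checking registration; the path is a placeholder.
llm = LLM(model="/path/to/your/model")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```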

## Step 4: Testing

After adding a new model, you should run basic functional tests (offline/online inference), an accuracy test, and a performance benchmark for the model.

Find more details at:

- [Accuracy test guide](https://vllm-ascend.readthedocs.io/en/latest/developer_guide/evaluation/index.html)
- [Performance benchmark guide](https://vllm-ascend.readthedocs.io/en/latest/developer_guide/performance/performance_benchmark.html)
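
As a quick online smoke test, you can serve the model and query its OpenAI-compatible API; the model path below is a placeholder:

```bash
# Start the server with the new model (placeholder path).
vllm serve /path/to/your/model

# In another shell, send a completion request.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "/path/to/your/model", "prompt": "Hello", "max_tokens": 32}'
```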

## Step 5: Updating Supported Models Doc

Finally, once all the steps above are completed, add the new model to our [Supported Models](https://vllm-ascend.readthedocs.io/en/latest/user_guide/supported_models.html) doc.