Commit 3687676

[Doc] Add guidance on how to implement and register new models (#1426)
### What this PR does / why we need it?
Add guidance on how to implement and register new models. Modified based on PR #1126; thanks to @linfeng-yuan for the contribution.
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 5571fb7 commit 3687676

File tree: 4 files changed, +272 / -0 lines changed

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@

# Adding a New Model

This guide demonstrates how to integrate a novel or customized model into vllm-ascend. For foundational concepts, it is highly recommended to refer to the [vllm official doc: Adding a New Model](https://docs.vllm.ai/en/stable/contributing/model/) first.

## Step 1: Implementing Models with `torch` and `torch_npu`

This section provides instructions for implementing new models compatible with vllm and vllm-ascend.

**Before starting:**

- Verify whether your model already exists in vllm's [models](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) directory.
- Use existing models' implementations as templates to accelerate your development.

### Method 1: Implementing New Models from Scratch

Follow vllm's [OPT model adaptation](https://docs.vllm.ai/en/stable/contributing/model/basic.html) example for guidance.

**Key implementation requirements:**

1. Place model files in the `vllm_ascend/models/` directory.

2. Standard module structure for decoder-only LLMs (please check out vllm's implementations for other kinds of models):

    - `*ModelForCausalLM` (top-level wrapper)
    - `*Model` (main architecture)
    - `*DecoderLayer` (transformer block)
    - `*Attention` and `*MLP` (specific computation units)

    :::{note}
    `*` denotes your model's unique identifier.
    :::

3. Critical Implementation Details:

    All modules must include a `prefix` argument in `__init__()`.

    **Required interfaces:**

    | Module Type          | Required Methods                                          |
    | :------------------- | :-------------------------------------------------------- |
    | `*ModelForCausalLM`  | `get_input_embeddings`, `compute_logits`, `load_weights`  |
    | `*Model`             | `get_input_embeddings`, `load_weights`                    |

4. Attention Backend Integration:

    Importing attention via `from vllm.attention import Attention` automatically leverages the attention backend routing of vllm-ascend (see `get_attn_backend_cls()` in `vllm_ascend/platform.py`).

5. Tensor Parallelism:

    Use vllm's parallel layers (`ColumnParallelLinear`, `VocabParallelEmbedding`, etc.) to implement models supporting tensor parallelism; a minimal sketch follows this list. Note that Ascend-specific customizations are implemented in the `vllm_ascend/ops/` directory (RMSNorm, VocabParallelEmbedding, etc.).
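
As an illustration of the parallel layers, here is a minimal gated-MLP sketch modeled on vllm's LLaMA implementation (the class name, sizes, and activation are placeholders; adapt them to your architecture):

```python
from torch import nn
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)


class CustomMLP(nn.Module):
    """Gated MLP built from vLLM's tensor-parallel linear layers."""

    def __init__(self, hidden_size: int, intermediate_size: int, prefix: str = ""):
        super().__init__()
        # Column-parallel projection; gate_proj and up_proj are merged into one matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False,
            prefix=f"{prefix}.gate_up_proj")
        # Row-parallel projection back to hidden size (performs the all-reduce under TP).
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False,
            prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # vLLM's parallel linear layers return an (output, output_bias) tuple.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
```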

**Reference Implementation Template** (assumed path: `vllm_ascend/models/custom_model.py`):

```python
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.model_executor.sampling_metadata import SamplingMetadata


class CustomAttention(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.attn = Attention(prefix=f"{prefix}.attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement attention logic
        ...


class CustomDecoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.self_attn = CustomAttention(vllm_config, prefix=f"{prefix}.self_attn")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Implement decoder layer
        ...


class CustomModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.layers = nn.ModuleList([
            CustomDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}")
            for i in range(vllm_config.model_config.hf_config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...


class CustomModelForCausalLM(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.model = CustomModel(vllm_config, prefix=f"{prefix}.model")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        ...

    def compute_logits(self,
                       hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        ...

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        ...
```

### Method 2: Customizing Existing vLLM Models

For most use cases, extending an existing implementation is preferable. The example below inherits from vllm's `DeepseekV2ForCausalLM` to implement a customized DeepSeek model (assumed path: `vllm_ascend/models/deepseek_v2.py`).

```python
from typing import List, Optional, Union

import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
from vllm.sequence import IntermediateTensors


class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
    # Define merged weights for quantization/efficiency
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Custom forward logic
        hidden_states = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds
        )
        return hidden_states
```

:::{note}
For a complete implementation reference, see: `vllm_ascend/models/deepseek_v2.py`.
:::

## Step 2: Registering Custom Models using ModelRegistry Plugins in vLLM

vllm provides a plugin mechanism for registering externally implemented models without modifying its codebase.

To integrate your implemented model from the `vllm_ascend/models/` directory:

1. Import your model implementation in `vllm_ascend/models/__init__.py` using relative imports.
2. Register the model wrapper class via the `vllm.ModelRegistry.register_model()` function.

**Reference Registration Template** (an example of registering new models in `vllm_ascend/models/__init__.py`):

```python
from vllm import ModelRegistry


def register_model():
    from .custom_model import CustomModelForCausalLM      # New custom model
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # Customized DeepSeek model

    # For NEW architectures: register with a unique name
    ModelRegistry.register_model(
        "CustomModelForCausalLM",  # Must match config.json's 'architectures'
        "vllm_ascend.models.custom_model:CustomModelForCausalLM"
    )

    # For MODIFIED architectures: use the original name
    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",  # Original architecture identifier in vLLM
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM"
    )
```
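
vllm discovers registration functions like this through its plugin entry points: at start-up it scans the `vllm.general_plugins` entry point group and calls every function registered there, which is how `register_model()` gets a chance to run before model classes are resolved. vllm-ascend ships this wiring in its own packaging, so nothing extra is needed for models placed in `vllm_ascend/models/`. If you were building a standalone plugin package instead, the wiring would look roughly like the sketch below (the package and entry point names are illustrative assumptions, not vllm-ascend's actual packaging):

```python
# setup.py of a hypothetical standalone plugin package (illustrative sketch).
from setuptools import setup

setup(
    name="my_vllm_plugin",            # hypothetical package name
    packages=["my_vllm_plugin"],
    entry_points={
        # vLLM calls every function in this group when general plugins are loaded.
        "vllm.general_plugins": [
            "my_models = my_vllm_plugin:register_model",
        ],
    },
)
```

Once such a package is installed into the same environment, vllm loads the plugin automatically; no change to vllm's codebase is required.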

:::{note}
The first argument of `vllm.ModelRegistry.register_model()` is the unique architecture identifier, which must match `architectures` in the model's `config.json`.

```json
{
    "architectures": [
        "CustomModelForCausalLM"
    ]
}
```
:::

## Step 3: Verification

### Case 1: Overriding an Existing vLLM Model Architecture

If you're registering a customized model architecture based on vllm's existing implementation (overriding vllm's original class), you'll observe warning logs similar to the following output from `vllm/model_executor/models/registry.py` when running vllm offline/online inference (with any model):

```bash
Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
```

### Case 2: Registering a New Model Architecture

If you're registering a novel model architecture not present in vllm (creating a completely new class), the current logs won't provide explicit confirmation by default. It's recommended to add the following logging statement at the end of the `register_model` method in `vllm/model_executor/models/registry.py`:

```python
logger.info(f"model_arch: {model_arch} has been registered here!")
```

After adding this line, you will see confirmation logs like the one below when running vllm offline/online inference (with any model):

```bash
model_arch: CustomModelForCausalLM has been registered here!
```

This log output confirms that your novel model architecture has been successfully registered in vllm.
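
If you prefer not to modify vllm's source, an alternative check (a sketch assuming your installed vllm exposes `ModelRegistry.get_supported_archs()`) is to call your registration hook directly in a Python shell and inspect the registry:

```python
# Quick registry check; normally register_model() is invoked by vLLM's plugin
# loader at engine start-up, here we call it by hand for verification only.
from vllm import ModelRegistry
from vllm_ascend.models import register_model

register_model()

# Both the new and the overridden architecture identifiers should now be listed.
print("CustomModelForCausalLM" in ModelRegistry.get_supported_archs())
print("DeepseekV2ForCausalLM" in ModelRegistry.get_supported_archs())
```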

## Step 4: Testing

After adding a new model, run a basic functional test (offline/online inference), an accuracy test, and a performance benchmark for the model. A quick offline smoke test can look like the sketch below.
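
This assumes your model weights live at a placeholder local path whose `config.json` lists the architecture you registered; swap in your own path or model ID.

```python
# Minimal offline smoke test for a newly registered model.
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/custom-model")  # placeholder path, not a real checkpoint
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```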

Find more details at:

- [Accuracy test guide](https://vllm-ascend.readthedocs.io/en/latest/developer_guide/evaluation/index.html)
- [Performance benchmark guide](https://vllm-ascend.readthedocs.io/en/latest/developer_guide/performance/performance_benchmark.html)

## Step 5: Updating the Supported Models Doc

Finally, once all the steps above are completed, add the new model to our [Supported Models](https://vllm-ascend.readthedocs.io/en/latest/user_guide/supported_models.html) doc.
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@

# Adding a New Multi-Modal Model

**_Coming soon ..._**
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@

# Modeling

This section provides tutorials on how to implement and register a new model into vllm-ascend.

:::{toctree}
:caption: Modeling
:maxdepth: 1
adding_a_new_model
adding_a_new_multimodal_model
:::

docs/source/index.md

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ developer_guide/contribution/index
 developer_guide/feature_guide/index
 developer_guide/evaluation/index
 developer_guide/performance/index
+developer_guide/modeling/index
 :::

 % How to involve vLLM Ascend
