
Commit 9f5ae8a

Merge branch 'main' into dev
2 parents 87be51e + 00e0243

File tree

10 files changed: +282 -93 lines changed


docs/source/tutorials/multi_npu_quantization.md

Lines changed: 16 additions & 22 deletions

@@ -1,4 +1,4 @@
-# Multi-NPU (deepseek-v2-lite-w8a8)
+# Multi-NPU (QwQ 32B W8A8)
 
 ## Run docker container:
 :::{note}
@@ -31,60 +31,54 @@ docker run --rm \
 ## Install modelslim and convert model
 :::{note}
-You can choose to convert the model yourself or use the quantized model we uploaded, see https://www.modelscope.cn/models/vllm-ascend/DeepSeek-V2-Lite-w8a8
+You can choose to convert the model yourself or use the quantized model we uploaded, see https://www.modelscope.cn/models/vllm-ascend/QwQ-32B-W8A8
 :::
 
 ```bash
-git clone https://gitee.com/ascend/msit
+# (Optional) This tag is recommended and has been verified
+git clone https://gitee.com/ascend/msit -b modelslim-VLLM-8.1.RC1.b020
 
-# (Optional) This commit has been verified
-git checkout a396750f930e3bd2b8aa13730401dcbb4bc684ca
 cd msit/msmodelslim
 # Install by running this script
 bash install.sh
 pip install accelerate
 
-cd /msit/msmodelslim/example/DeepSeek
+cd example/Qwen
 # Original weight path, replace with your local model path
-MODEL_PATH=/home/weight/DeepSeek-V2-Lite
+MODEL_PATH=/home/models/QwQ-32B
 # Path to save the converted weight, replace with your local path
-SAVE_PATH=/home/weight/DeepSeek-V2-Lite-w8a8
-mkdir -p $SAVE_PATH
+SAVE_PATH=/home/models/QwQ-32B-w8a8
+
 # An NPU device is not required for this conversion; you can also set --device_type cpu to run it
-python3 quant_deepseek.py --model_path $MODEL_PATH --save_directory $SAVE_PATH --device_type npu --act_method 2 --w_bit 8 --a_bit 8 --is_dynamic True
+python3 quant_qwen.py --model_path $MODEL_PATH --save_directory $SAVE_PATH --calib_file ../common/boolq.jsonl --w_bit 8 --a_bit 8 --device_type npu --anti_method m1 --trust_remote_code True
 ```
 
 ## Verify the quantized model
 The converted model files look like:
 ```bash
 .
 |-- config.json
-|-- configuration_deepseek.py
-|-- fusion_result.json
+|-- configuration.json
 |-- generation_config.json
-|-- quant_model_description_w8a8_dynamic.json
-|-- quant_model_weight_w8a8_dynamic-00001-of-00004.safetensors
-|-- quant_model_weight_w8a8_dynamic-00002-of-00004.safetensors
-|-- quant_model_weight_w8a8_dynamic-00003-of-00004.safetensors
-|-- quant_model_weight_w8a8_dynamic-00004-of-00004.safetensors
-|-- quant_model_weight_w8a8_dynamic.safetensors.index.json
-|-- tokenization_deepseek_fast.py
+|-- quant_model_description.json
+|-- quant_model_weight_w8a8.safetensors
+|-- README.md
 |-- tokenizer.json
 `-- tokenizer_config.json
 ```
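As a quick sanity check that the conversion produced the files listed above, something like the following can be run against the SAVE_PATH used earlier. This is a minimal sketch; the path is illustrative and the file names are taken from the new listing in the diff.

```python
# Minimal sanity check for the converted weights (illustrative path; adjust to your SAVE_PATH).
from pathlib import Path

save_path = Path("/home/models/QwQ-32B-w8a8")
expected = [
    "config.json",
    "quant_model_description.json",
    "quant_model_weight_w8a8.safetensors",
    "tokenizer.json",
    "tokenizer_config.json",
]

missing = [name for name in expected if not (save_path / name).exists()]
print("missing files:", missing or "none")
```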
 
 Run the following script to start the vLLM server with the quantized model:
 ```bash
-vllm serve /home/weight/DeepSeek-V2-Lite-w8a8 --tensor-parallel-size 4 --trust-remote-code --served-model-name "dpsk-w8a8" --max-model-len 4096
+vllm serve /home/models/QwQ-32B-w8a8 --tensor-parallel-size 4 --served-model-name "qwq-32b-w8a8" --max-model-len 4096 --quantization ascend
 ```
 
 Once your server is started, you can query the model with input prompts:
 ```bash
 curl http://localhost:8000/v1/completions \
     -H "Content-Type: application/json" \
     -d '{
-        "model": "dpsk-w8a8",
-        "prompt": "what is deepseek?",
+        "model": "qwq-32b-w8a8",
+        "prompt": "what is large language model?",
         "max_tokens": "128",
         "top_p": "0.95",
         "top_k": "40",

pytest.ini

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@ addopts = --ignore=vllm-empty/tests/test_utils.py
     --ignore=vllm-empty/tests/detokenizer/test_stop_reason.py
     ; oom on llama-2-7b-hf
     --ignore=vllm-empty/tests/detokenizer/test_stop_strings.py
+    ; no need to run on vllm-ascend
+    --ignore=vllm-empty/tests/test_vllm_port.py
 
 testpaths =
     vllm-empty/tests

vllm_ascend/attention/attention_v1.py

Lines changed: 10 additions & 2 deletions

@@ -30,6 +30,7 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import vllm_version_is
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -144,8 +145,15 @@ def build(self, num_reqs, num_actual_tokens, max_query_len,
         query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
+
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
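The block-table access is now gated on the installed vLLM version (the same pattern appears in mla_v1.py below), since vLLM releases after 0.8.5 expose `input_batch.block_table` as an indexable collection of per-KV-cache-group tables rather than a single table. The helper `vllm_version_is` is imported from `vllm_ascend.utils`; a minimal sketch of such a check, shown as an assumption for illustration (the real helper may differ):

```python
# Illustrative sketch of a version gate (assumption: the actual vllm_ascend.utils helper may differ).
from importlib.metadata import version


def vllm_version_is(target: str) -> bool:
    """Return True when the installed vLLM package version matches `target` exactly."""
    return version("vllm") == target


if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
    print("using the single block_table layout")
else:
    print("using the per-group block_table layout (block_table[0])")
```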

vllm_ascend/attention/mla_v1.py

Lines changed: 8 additions & 3 deletions

@@ -16,6 +16,7 @@
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -238,8 +239,12 @@ def build(self,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = (self.runner.input_batch.block_table[0].
+                           get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
@@ -795,4 +800,4 @@ def forward(
         output[:num_decode_tokens] = self._forward_decode(
             decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
             kv_cache, attn_metadata)
-        return output_padded
+        return output_padded

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 109 additions & 4 deletions

@@ -36,9 +36,9 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
-    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
-    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
-    Qwen2_5_VLProcessingInfo)
+    Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
+    Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration,
+    Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
@@ -152,6 +152,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__(dim, theta)
+        inv_freq = 1.0 / (theta
+                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.inv_freq = inv_freq
+
+
 class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
 
     def __init__(
@@ -166,6 +175,9 @@ def __init__(
         norm_layer = partial(RMSNorm, eps=norm_eps)
         self.interleaved = interleaved
         self.enable_pad = False
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
+                                                                  2)
         self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
             patch_size=vision_config.patch_size,
             temporal_patch_size=vision_config.temporal_patch_size,
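The new `AscendQwen2_5_VisionRotaryEmbedding` recomputes the rotary inverse frequencies in float32. For reference, the expression it evaluates is the standard RoPE frequency schedule; a small standalone numeric check with illustrative values:

```python
import torch

dim, theta = 8, 10000.0
# inv_freq[i] = theta ** (-2i / dim), computed in float32 as in the class above
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])
```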
@@ -298,6 +310,66 @@ def load_weights(self, weights: Iterable[Tuple[str,
             loaded_params.add(name)
         return loaded_params
 
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
     def forward(
         self,
         x: torch.Tensor,
@@ -366,4 +438,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
-        )
+        )
+
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())
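The two `_process_*_input` overrides split the concatenated visual embeddings back into per-item tensors, using each item's grid size divided by the spatial merge factor. A standalone sketch of that arithmetic, with illustrative shapes that are not taken from the commit:

```python
import torch

grid_thw = torch.tensor([[1, 32, 32], [1, 16, 16]])  # (t, h, w) per image
merge_size = 2  # spatial_merge_size: each 2x2 patch group becomes one token
hidden = 1280

# tokens contributed by each image after the spatial merge
sizes = grid_thw.prod(-1) // merge_size // merge_size  # tensor([256, 64])

image_embeds = torch.randn(int(sizes.sum()), hidden)  # concatenated embeddings
per_image = image_embeds.split(sizes.tolist())         # one tensor per image
print([tuple(t.shape) for t in per_image])             # [(256, 1280), (64, 1280)]
```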
