 67 |  67 |
 68 |  68 | import vllm_ascend.envs as envs_ascend
 69 |  69 | from vllm_ascend.ascend_config import get_ascend_config
    |  70 | +from vllm_ascend.distributed.parallel_state import get_ep_group
 70 |  71 | from vllm_ascend.ops.fused_moe import AscendFusedMoE
 71 |  72 | from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 72 |  73 | from vllm_ascend.utils import dispose_tensor
@@ -211,13 +212,15 @@ def __init__(
211 | 212 |
212 | 213 |         self.tp_group = get_tp_group().device_group
213 | 214 |         self.tp_rank = get_tp_group().rank_in_group
    | 215 | +        self.ep_group = get_ep_group()
214 | 216 |
215 | 217 |         self.params_dtype = torch.get_default_dtype()
216 | 218 |
217 | 219 |         ascend_config = get_ascend_config()
218 | 220 |         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
    | 221 | +        # NOTE: multistream is only effective when `VLLM_ENABLE_MC2` is on
219 | 222 |         self.enable_multistream_shared_expert = \
220 |     | -            ascend_config.torchair_graph_config.enable_multistream_shared_expert
    | 223 | +            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
221 | 224 |
222 | 225 |     def forward(
223 | 226 |         self,
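A minimal sketch of the effective gating after this change (the standalone helper below is illustrative only, not part of the patch; the bare `VLLM_ENABLE_MC2` used in the new lines is assumed to be bound elsewhere in this file from `envs_ascend.VLLM_ENABLE_MC2`):

```python
# Illustrative only: the multistream shared-expert flag is now ANDed with the
# MC2 switch, matching the NOTE added above.
def multistream_shared_expert_enabled(config_flag: bool,
                                      vllm_enable_mc2: bool) -> bool:
    # The multistream shared-expert path is only honoured when MC2 is also on.
    return config_flag and vllm_enable_mc2

assert multistream_shared_expert_enabled(True, False) is False  # MC2 off -> disabled
assert multistream_shared_expert_enabled(True, True) is True    # both on -> enabled
```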
@@ -245,16 +248,12 @@ def forward(
245 | 248 |         old_hidden_states = hidden_states.clone()
246 | 249 |
247 | 250 |         if self.tp_size > 1:
248 |     | -            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
249 |     | -                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
250 |     | -                hidden_states = chunks[self.tp_rank]
251 |     | -            elif not self.torchair_graph_enabled:
252 |     | -                num_padding_tokens = (self.tp_size -
253 |     | -                                      num_tokens % self.tp_size) % self.tp_size
254 |     | -                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
255 |     | -                if num_padding_tokens > 0:
    | 251 | +            if (VLLM_ENABLE_MC2
    | 252 | +                    and not is_prefill) or not (self.torchair_graph_enabled or
    | 253 | +                                                self.ep_group.world_size == 1):
    | 254 | +                if num_tokens < self.tp_size:
256 | 255 |                     hidden_states = nn.functional.pad(
257 |     | -                        hidden_states, (0, 0, 0, num_padding_tokens))
    | 256 | +                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
258 | 257 |                 chunk_hidden_states = torch.tensor_split(hidden_states,
259 | 258 |                                                           self.tp_size,
260 | 259 |                                                           dim=0)
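The new branch replaces the modulo padding: `torch.tensor_split` already tolerates token counts that are not divisible by `tp_size`, so padding is only needed when there are fewer tokens than TP ranks, presumably so that no rank receives an empty chunk. A single-process sketch of that dispatch behaviour (the `pad_and_split` helper is hypothetical, for illustration only):

```python
# Single-process sketch of the new dispatch path, assuming hidden_states has
# shape [num_tokens, hidden_size] as in the diff. `pad_and_split` is a
# hypothetical helper name, not part of the patch.
import torch
import torch.nn as nn

def pad_and_split(hidden_states: torch.Tensor, tp_size: int):
    num_tokens = hidden_states.shape[0]
    if num_tokens < tp_size:
        # Pad the token dimension up to tp_size so every rank gets a chunk.
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, tp_size - num_tokens))
    # tensor_split handles uneven sizes, so no modulo padding is required.
    return torch.tensor_split(hidden_states, tp_size, dim=0)

chunks = pad_and_split(torch.randn(3, 8), tp_size=4)   # fewer tokens than ranks
assert [c.shape[0] for c in chunks] == [1, 1, 1, 1]
chunks = pad_and_split(torch.randn(10, 8), tp_size=4)  # uneven split, no padding
assert [c.shape[0] for c in chunks] == [3, 3, 2, 2]
```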
@@ -284,24 +283,16 @@ def forward(
284 | 283 |         hidden_states = hidden_states * self.routed_scaling_factor
285 | 284 |
286 | 285 |         if self.tp_size > 1:
287 |     | -            if self.torchair_graph_enabled:
288 |     | -                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
289 |     | -                    final_hidden_states = torch.zeros(
290 |     | -                        [num_tokens, hidden_size],
291 |     | -                        dtype=self.params_dtype,
292 |     | -                        device="npu")
293 |     | -                    dist.all_gather_into_tensor(final_hidden_states,
294 |     | -                                                hidden_states, self.tp_group)
295 |     | -                    hidden_states = final_hidden_states
296 |     | -                else:
297 |     | -                    hidden_states = tensor_model_parallel_all_reduce(
298 |     | -                        hidden_states)
299 |     | -            else:
    | 286 | +            if (VLLM_ENABLE_MC2
    | 287 | +                    and not is_prefill) or not (self.torchair_graph_enabled or
    | 288 | +                                                self.ep_group.world_size == 1):
300 | 289 |                 dist.all_gather(list(chunk_hidden_states), hidden_states,
301 | 290 |                                 self.tp_group)
302 | 291 |                 hidden_states = torch.cat(chunk_hidden_states, dim=0)
303 |     | -                if num_padding_tokens > 0:
304 |     | -                    hidden_states = hidden_states[:-num_padding_tokens]
    | 292 | +                if num_tokens < self.tp_size:
    | 293 | +                    hidden_states = hidden_states[:num_tokens]
    | 294 | +            else:
    | 295 | +                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
305 | 296 |
306 | 297 |         if self.n_shared_experts is not None:
307 | 298 |             if not multistream:
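On the combine side, the gathered chunks are concatenated and any padding rows are sliced off with `[:num_tokens]`; otherwise the code falls back to a plain all-reduce. A single-process sketch of the round trip, with the local chunk list standing in for the buffers that `dist.all_gather` fills across TP ranks in the real code:

```python
# Single-process sketch: cat + [:num_tokens] undoes the pad/tensor_split round
# trip from the dispatch side. In the real code dist.all_gather fills
# chunk_hidden_states across ranks; locally the chunks already hold the values.
import torch
import torch.nn as nn

num_tokens, hidden_size, tp_size = 3, 8, 4
hidden_states = torch.randn(num_tokens, hidden_size)

padded = nn.functional.pad(hidden_states, (0, 0, 0, tp_size - num_tokens))
chunk_hidden_states = list(torch.tensor_split(padded, tp_size, dim=0))

combined = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < tp_size:
    combined = combined[:num_tokens]  # drop the padding rows added at dispatch

assert torch.equal(combined, hidden_states)
```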