[bugfix] Add ep initialization check and change the return check to is_driver_worker #896


Merged
1 commit merged on Jun 3, 2025
2 changes: 1 addition & 1 deletion tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
@@ -89,4 +89,4 @@ def test_mtp_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
del spec_llm
3 changes: 3 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -129,6 +129,9 @@ class AscendMetadata:
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None

# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.


class AscendAttentionMetadataBuilder:

6 changes: 6 additions & 0 deletions vllm_ascend/distributed/parallel_state.py
@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
return _ETP


def model_parallel_initialized():
return (_ETP is not None and _EP is not None)
Collaborator:

I think we could use EP without ETP, so this check would break that scenario.

Contributor (author):

No. If ETP is not enabled, the communication groups will still be created.



def init_ascend_model_parallel(
expert_parallel_size: int = 1,
expert_tensor_parallel_size: int = 1,
world_size: Optional[int] = None,
backend: Optional[str] = None,
):
if model_parallel_initialized():
return
assert torch.distributed.is_initialized()
world_size = world_size or torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
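For reference, a minimal runnable sketch of the initialization guard this diff introduces. The placeholder group objects and argument values below are illustrative stand-ins, not the real GroupCoordinator construction:

# Sketch of the idempotency guard added above; the group objects are
# placeholders, not real torch.distributed process groups.
_ETP = None  # expert tensor parallel group
_EP = None   # expert parallel group

def model_parallel_initialized() -> bool:
    # Both groups are created together in init_ascend_model_parallel,
    # so "both non-None" means initialization already happened.
    return _ETP is not None and _EP is not None

def init_ascend_model_parallel(expert_parallel_size: int = 1,
                               expert_tensor_parallel_size: int = 1) -> None:
    global _EP, _ETP
    if model_parallel_initialized():
        return  # repeated calls become harmless no-ops
    _EP = object()   # stand-in for the real EP group construction
    _ETP = object()  # stand-in for the real ETP group construction

init_ascend_model_parallel(2, 1)
init_ascend_model_parallel(2, 1)  # second call returns early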
3 changes: 1 addition & 2 deletions vllm_ascend/ops/fused_moe.py
@@ -66,8 +66,7 @@ def fused_experts_with_mc2(
local_rank = torch.distributed.get_rank(group=ep_group)
all_to_all_group_size = torch.distributed.get_world_size(ep_group)

world_szie = torch.distributed.get_world_size()
tp_size = world_szie // all_to_all_group_size
tp_size = get_etp_group().world_size
Collaborator:

@ganyi1996ppo please double check this change.

Collaborator:

Looks fine.

tp_rank = rank % tp_size

stage1_kwargs = {
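A small arithmetic sketch of the tp_size change, assuming the usual decomposition world_size = ep_size * etp_size (the numbers below are made up for illustration):

# Illustrative numbers only; assumes world_size = ep_size * etp_size.
world_size = 16   # total ranks
ep_size = 8       # expert parallel group size (all_to_all_group_size)
etp_size = 2      # expert tensor parallel group size

old_tp_size = world_size // ep_size  # old code: derived from the global world size
new_tp_size = etp_size               # new code: read directly from the ETP group

rank = 5
tp_rank = rank % new_tp_size         # 1

assert old_tp_size == new_tp_size    # equal under the assumed decomposition

Reading the size from the ETP group directly also drops the misspelled world_szie temporary that the old code introduced.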
4 changes: 2 additions & 2 deletions vllm_ascend/patch/platform/patch_common/patch_distributed.py
@@ -20,6 +20,7 @@
import torch
import vllm
import vllm.distributed
import vllm.envs as envs
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
@@ -164,10 +165,9 @@ def parallel_config_get_dp_port(self) -> int:
"""
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
import os

# NOTE: Get port from envs directly when using torchrun
port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore
port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
return port


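A self-contained sketch of the port-selection fallback after this change, with vllm.envs replaced by a plain dict so the example runs stand-alone; ParallelConfigStub is a hypothetical stand-in for the patched method:

import os

# Stand-in for vllm.envs; VLLM_DP_MASTER_PORT is read from the environment.
fake_envs = {"VLLM_DP_MASTER_PORT": int(os.environ.get("VLLM_DP_MASTER_PORT", 0))}

class ParallelConfigStub:
    # Hypothetical stand-in for the patched parallel_config_get_dp_port.
    def __init__(self, base_port: int = 29500) -> None:
        self.data_parallel_master_port = base_port

    def get_next_dp_init_port(self) -> int:
        answer = self.data_parallel_master_port
        self.data_parallel_master_port += 1
        # Prefer the explicit VLLM_DP_MASTER_PORT override (e.g. under torchrun),
        # otherwise fall back to the locally incremented port.
        return fake_envs["VLLM_DP_MASTER_PORT"] or answer

cfg = ParallelConfigStub()
print(cfg.get_next_dp_init_port())  # 29500 unless VLLM_DP_MASTER_PORT is set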
2 changes: 1 addition & 1 deletion vllm_ascend/worker/mtp_proposer_v1.py
@@ -219,4 +219,4 @@ def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,

global_indices_flat = global_indices[mask]
values_flat = values[mask]
out_ptr[global_indices_flat] = values_flat
out_ptr[global_indices_flat] = values_flat
2 changes: 1 addition & 1 deletion vllm_ascend/worker/worker_v1.py
@@ -173,7 +173,7 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
return output if self.is_driver_worker else None

def load_model(self) -> None:
self.model_runner.load_model()
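A minimal sketch of the return-guard change in execute_model; WorkerStub and the output dict are mock stand-ins for the real worker and ModelRunnerOutput, and only the guard itself mirrors the patch:

from typing import Optional

class WorkerStub:
    # Mock worker; is_driver_worker replaces the old `rank == 0` check.
    def __init__(self, rank: int, is_driver_worker: bool) -> None:
        self.rank = rank
        self.is_driver_worker = is_driver_worker

    def execute_model(self, scheduler_output) -> Optional[dict]:
        output = {"sampled_token_ids": [1, 2, 3]}  # stand-in for ModelRunnerOutput
        # Only the driver worker returns the output; other workers return None,
        # which also covers setups where the driver is not global rank 0.
        return output if self.is_driver_worker else None

print(WorkerStub(rank=4, is_driver_worker=True).execute_model(None))   # output dict
print(WorkerStub(rank=0, is_driver_worker=False).execute_model(None))  # None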