Skip to content

[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE #20166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 8, 2025

Conversation

minosfuture
Copy link
Contributor

@minosfuture minosfuture commented Jun 27, 2025

Purpose

This PR fixes the following error when starting EP on Maverick:

(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]     run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]   File "/home/yeq/gitrepos/vllm/vllm/model_executor/layers/fused_moe/cutlass_moe.py", line 89, in run_cutlass_moe_fp8
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]     local_topk_ids = torch.where(expert_map[topk_ids] != -1,
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527]                                  ~~~~~~~~~~^^^^^^^^^^
(VllmWorker rank=3 pid=1737537) ERROR 06-15 22:58:28 [multiproc_executor.py:527] IndexError: tensors used as indices must be long, int, byte or bool tensors

In the PPLX implementation #18762, the dtype got flipped to uint32, here.

Besides this fix, the workspace_shapes needed another fix here from #19168, which is already merged; otherwise, the torch.zeros is slow for processing much larger size of data here.

Test Plan

  1. benchmark for latency sanity
# serve
vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \
        --max_model_len 8192 \
        --kv_cache_dtype fp8 \
        --enable-expert-parallel \
        --tensor-parallel-size 8 \
        --trust-remote-code \
        --enforce_eager \
        --gpu-memory-utilization 0.8 \
        --disable-log-requests 2>&1 | tee ep_`date +%Y%m%d_%H%M%S`.log
# benchmark serve
python benchmarks/benchmark_serving.py  --model meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 \
        --port 8000  --dataset-name random  --ignore-eos  --num-prompts 500   --max-concurrency 128 \
        --random-input-len 2000 --random-output-len 150
  1. lm_eval

Test Result

runtime exception during init is fixed. Attaching benchmark results:

============ Serving Benchmark Result ============
Successful requests:                     500
Benchmark duration (s):                  42.34
Total input tokens:                      998815
Total generated tokens:                  75000
Request throughput (req/s):              11.81
Output token throughput (tok/s):         1771.43
Total Token throughput (tok/s):          25362.46
---------------Time to First Token----------------
Mean TTFT (ms):                          1119.22
Median TTFT (ms):                        384.95
P99 TTFT (ms):                           5939.92
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          63.31
Median TPOT (ms):                        66.14
P99 TPOT (ms):                           67.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           63.31
Median ITL (ms):                         33.70
P99 ITL (ms):                            198.63
==================================================

lm_eval results:

local-chat-completions (model=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,base_url=http://127.0.0.1:8081/v1/chat/completions,num_concurrent=32), gen_kwargs: (None), limit: 200.0, num_fewshot: 5, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.935 ± 0.0175
strict-match 5 exact_match 0.920 ± 0.0192

with cuda graph:

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.940 ± 0.0168
strict-match 5 exact_match 0.925 ± 0.0187

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @minosfuture, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request provides a crucial bugfix for the Mixture of Experts (MoE) implementation, specifically for FP8 quantization with CUTLASS. It resolves a runtime IndexError that prevented the successful initialization and execution of models utilizing this configuration, ensuring the stability and functionality of FP8 MoE operations.

Highlights

  • Bugfix: MoE FP8 Indexing: This pull request addresses a critical IndexError occurring during the execution of CUTLASS w8a8 FP8 MoE (Mixture of Experts) operations. The error stemmed from topk_ids tensors being incorrectly cast to torch.uint32, which is not a valid type for indexing in PyTorch.
  • Code Correction: The fix involves removing the explicit indices_type=torch.uint32 argument from the apply function call within the fused_experts initialization in compressed_tensors_moe.py. This allows the system to use the correct default or inferred integer type (e.g., torch.long) for indexing, resolving the runtime crash.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a runtime IndexError that occurred during MoE execution with CUTLASS. The error was caused by topk_ids having an unsupported uint32 dtype for indexing. The fix, which removes the indices_type=torch.uint32 argument from the routing function call, is direct and effective, allowing the topk_ids tensor to default to a valid type for indexing. The change is well-supported by the provided error log and test results.

@yeqcharlotte
Copy link
Collaborator

thanks for the fix! could you also share the eval result? has cudagraph worked it?

cc: @ElizaWszola @bnellnm to take a look!

@minosfuture
Copy link
Contributor Author

minosfuture commented Jun 27, 2025

thanks for the fix! could you also share the eval result? has cudagraph worked it?

cc: @ElizaWszola @bnellnm to take a look!

updated with lm-eval results. Note that it's tested with the correctness fix #20167. Yes, both eager and cuda graph work.

@ElizaWszola
Copy link
Contributor

Thanks for the fix! Can you please check if the kernels in csrc/quantization/cutlass_w8a8/moe/moe_data.cu that use uint32_t topk_ids will still work and compile without complaints if you change the types to int32_t and update your pr to use int32_t in these functions? If this breaks the kernels, it would be good to have an explicit conversions to uint32_t when we want to call them.

@minosfuture
Copy link
Contributor Author

Thanks for the fix! Can you please check if the kernels in csrc/quantization/cutlass_w8a8/moe/moe_data.cu that use uint32_t topk_ids will still work and compile without complaints if you change the types to int32_t and update your pr to use int32_t in these functions? If this breaks the kernels, it would be good to have an explicit conversions to uint32_t when we want to call them.

updated.

In PplxPrepareAndFinalize and DeepEPLLPrepareAndFinalize, topk_indices_dtype returns uint32 and int64, respectively. I suggest we change them to int32 for consistency? Keeping them as is shouldn't result in casting error though given that ids should be within a small range from zero.

@tlrmchlsmth
Copy link
Collaborator

tlrmchlsmth commented Jul 3, 2025

Thanks for the fix! Can you please check if the kernels in csrc/quantization/cutlass_w8a8/moe/moe_data.cu that use uint32_t topk_ids will still work and compile without complaints if you change the types to int32_t and update your pr to use int32_t in these functions? If this breaks the kernels, it would be good to have an explicit conversions to uint32_t when we want to call them.

updated.

In PplxPrepareAndFinalize and DeepEPLLPrepareAndFinalize, topk_indices_dtype returns uint32 and int64, respectively. I suggest we change them to int32 for consistency? Keeping them as is shouldn't result in casting error though given that ids should be within a small range from zero.

+1 to changing topk-ids from uint32 to int32 for consistency.

The topk-ids for pplx-kernels are uint32_t, so we should take care to avoid issues there.

I agree that we don't need to add an explicit cast. However let's make sure there's a comment warning about this in pplx-prepare-finalize and an assert that there's no expert map, since that -1 is used to signify "not this rank's token" in that case.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once we have some topk_id warning labels around the int32/uint32 danger in fused_moe/pplx_prepare_finalize.py

Copy link

mergify bot commented Jul 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @minosfuture.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 3, 2025
@tlrmchlsmth tlrmchlsmth closed this Jul 3, 2025
@tlrmchlsmth tlrmchlsmth reopened this Jul 3, 2025
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll also need to change the pplx_prepare_finalize to return int32_t instead

@tlrmchlsmth
Copy link
Collaborator

@bnellnm see my comments above. Do you see any issues there?

@mgoin mgoin mentioned this pull request Jul 3, 2025
@bnellnm
Copy link
Contributor

bnellnm commented Jul 3, 2025

@bnellnm see my comments above. Do you see any issues there?

I guess that would be the simplest thing. I agree that we'll need to be careful with pplx.

@yeqcharlotte
Copy link
Collaborator

simplest

then let’s run some e2e eval on deepseek r1 with pplx? could you share the setup?

Signed-off-by: Ming Yang <yming@meta.com>
Signed-off-by: Ming Yang <yming@meta.com>
Signed-off-by: Ming Yang <yming@meta.com>
@mergify mergify bot removed the needs-rebase label Jul 3, 2025
Signed-off-by: Ming Yang <yming@meta.com>
# assert expert_map is None, "NYI"
assert expert_map is None, """with expert map, -1 id is used for
non-local token; this causes error when casting ids to the
topk_indices_dtype() uint32"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added assertion and changed the id type here. test_pplx_moe.py passes. Let me know if we should run a model e2e and how. Thx! @tylertitsworth @bnellnm cc @yeqcharlotte

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this again, I don't think we need both the assertion and int32 change?
I don't know enough internals of pplx comms to decide this. Let me know how we wanna finalize this change, @tlrmchlsmth @bnellnm. Thanks!

Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @minosfuture Could we lift this assert ? I think PrepareAndFinalize implementations should allow callers to pass in a "valid" expert_map.

I believe a better solution is to make expert_map=None inside this function with comment with expert map, -1 id is used for non-local token; this causes error when casting ids to the topk_indices_dtype() uint32 to prevent issues with pplx + expert_map.

fyi @tlrmchlsmth

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me raise a PR for this! thanks for the catch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be problematic to clear expert_map assigning None to it?
do you know why an expert_map is passed in in the first place?

Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be problematic to clear expert_map assigning None to it?

The function doesn't use the expert_map - so overriding it to None should be fine as it'd prevent any incorrect use.

do you know why an expert_map is passed in in the first place?

I actually dont see it being used in any of the implementations. I think it is okay to remove the argument altogether.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameter is needed for this is an inherited function.

I commented it all out. This essentially reverts the changes here. As a followup, I think we should fix where it passes an expert_map, and add the assertion back.

#20714 @varun-sundar-rabindranath @tlrmchlsmth pls help approve this fix pr and auto-merge it. thx!

Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameter is needed for this is an inherited function.

I was suggesting we could remove the var in the base-class as I don't see any implementation using it.
I approved it, we can do the cleanups in a followup. Thanks 🙌

a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a fix needed after rebase #19636. cc @bnellnm @luccafong

@facebook-github-bot
Copy link

This pull request has been imported. If you are a Meta employee, you can view this in D77811621.

return torch.int64
return torch.int32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

concerned this will break things

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this id a small number for deepep_ll?
Can I run some tests to confirm it's safe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can also restore type changes for pplx and deepep_ll in this PR, and work on it in a new one.
Hoping to get this PR in to unblock maverick devs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep -- I'm hitting:

RuntimeError: Failed: Assertion error /app/DeepEP/csrc/deep_ep.cpp:1030 'topk_idx.scalar_type() == torch::kInt64'

Let's revert this line, and otherwise lgtm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated! Pls help trigger auto merge. Thanks for reviewing both PRs!

Signed-off-by: Ming Yang <yming@meta.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 8, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 8, 2025 12:34
@tlrmchlsmth tlrmchlsmth merged commit c438183 into vllm-project:main Jul 8, 2025
107 checks passed
@@ -81,7 +81,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens

def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.uint32
return torch.int32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note @tlrmchlsmth @varun-sundar-rabindranath - this appears to have broken the PPLX backend as the pplx dispatch function expects uint32

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VLLM_ALL2ALL_BACKEND="pplx" vllm serve Qwen/Qwen3-30B-A3B-FP8 --data-parallel-size 2 --enable-expert-parallel --enforce-eager

result:

(EngineCore_0 pid=1149186) ERROR 07-12 19:10:48 [core.py:586] RuntimeError: indices must be of type UInt32
(EngineCore_0 pid=1149186) Process EngineCore_0:
(EngineCore_0 pid=1149186) Traceback (most recent call last):
(EngineCore_0 pid=1149186)   File "/home/rshaw/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_0 pid=1149186)     self.run()
(EngineCore_0 pid=1149186)   File "/home/rshaw/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_0 pid=1149186)     self._target(*self._args, **self._kwargs)
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 590, in run_engine_core
(EngineCore_0 pid=1149186)     raise e
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 575, in run_engine_core
(EngineCore_0 pid=1149186)     engine_core = DPEngineCoreProc(*args, **kwargs)
(EngineCore_0 pid=1149186)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 835, in __init__
(EngineCore_0 pid=1149186)     super().__init__(vllm_config, local_client, handshake_address,
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 404, in __init__
(EngineCore_0 pid=1149186)     super().__init__(vllm_config, executor_class, log_stats,
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 82, in __init__
(EngineCore_0 pid=1149186)     self._initialize_kv_caches(vllm_config)
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
(EngineCore_0 pid=1149186)     available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore_0 pid=1149186)                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
(EngineCore_0 pid=1149186)     output = self.collective_rpc("determine_available_memory")
(EngineCore_0 pid=1149186)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
(EngineCore_0 pid=1149186)     answer = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_0 pid=1149186)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/utils/__init__.py", line 2955, in run_method
(EngineCore_0 pid=1149186)     return func(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_0 pid=1149186)     return func(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/worker/gpu_worker.py", line 219, in determine_available_memory
(EngineCore_0 pid=1149186)     self.model_runner.profile_run()
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/worker/gpu_model_runner.py", line 2239, in profile_run
(EngineCore_0 pid=1149186)     = self._dummy_run(self.max_num_tokens, is_profile=True)
(EngineCore_0 pid=1149186)       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_0 pid=1149186)     return func(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/v1/worker/gpu_model_runner.py", line 2020, in _dummy_run
(EngineCore_0 pid=1149186)     outputs = model(
(EngineCore_0 pid=1149186)               ^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186)     return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186)     return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 529, in forward
(EngineCore_0 pid=1149186)     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(EngineCore_0 pid=1149186)                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/compilation/decorators.py", line 173, in __call__
(EngineCore_0 pid=1149186)     return self.forward(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 369, in forward
(EngineCore_0 pid=1149186)     hidden_states, residual = layer(positions, hidden_states, residual)
(EngineCore_0 pid=1149186)                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186)     return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186)     return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 313, in forward
(EngineCore_0 pid=1149186)     hidden_states = self.mlp(hidden_states)
(EngineCore_0 pid=1149186)                     ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186)     return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186)     return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 136, in forward
(EngineCore_0 pid=1149186)     final_hidden_states = self.experts(hidden_states=hidden_states,
(EngineCore_0 pid=1149186)                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186)     return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186)     return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1381, in forward
(EngineCore_0 pid=1149186)     return torch.ops.vllm.moe_forward(hidden_states, router_logits,
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(EngineCore_0 pid=1149186)     return self._op(*args, **(kwargs or {}))
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1566, in moe_forward
(EngineCore_0 pid=1149186)     return self.forward_impl(hidden_states, router_logits)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1465, in forward_impl
(EngineCore_0 pid=1149186)     return self.forward_impl_chunked(hidden_states, router_logits)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1454, in forward_impl_chunked
(EngineCore_0 pid=1149186)     process_chunk(chunk_start,
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1415, in process_chunk
(EngineCore_0 pid=1149186)     final_hidden_states = self.quant_method.apply(
(EngineCore_0 pid=1149186)                           ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/quantization/fp8.py", line 970, in apply
(EngineCore_0 pid=1149186)     return self.fused_experts(
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186)     return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186)     return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 689, in forward
(EngineCore_0 pid=1149186)     _expert_topk_weights) = self.prepare_finalize.prepare(
(EngineCore_0 pid=1149186)                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py", line 202, in prepare
(EngineCore_0 pid=1149186)     self.a2a.dispatch(
(EngineCore_0 pid=1149186)   File "/home/rshaw/vllm/tools/ep_kernels/ep_kernels_workspace/pplx-kernels/src/pplx_kernels/all_to_all.py", line 48, in dispatch
(EngineCore_0 pid=1149186)     self._dispatch_fn(
(EngineCore_0 pid=1149186)   File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(EngineCore_0 pid=1149186)     return self._op(*args, **(kwargs or {}))
(EngineCore_0 pid=1149186)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) RuntimeError: indices must be of type UInt32

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-redhat that will be fixed by #20825

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay cool - just a note: I was using triton+pplx in this case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the regression! I'll make sure to test it next time. We are also working on improving CI coverage and preparing a script to cover major models for easy local test.

VLLM_ALL2ALL_BACKEND="pplx" vllm serve Qwen/Qwen3-30B-A3B-FP8 --data-parallel-size 2 --enable-expert-parallel --enforce-eager

Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
…ject#20166)

Signed-off-by: Ming Yang <yming@meta.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants