-
-
Notifications
You must be signed in to change notification settings - Fork 8.9k
[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE #20166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
6d0e19a
7c57bb0
376bcde
6ca83a4
ec668c3
beeba9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,7 +81,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: | |
return self.max_num_tokens | ||
|
||
def topk_indices_dtype(self) -> Optional[torch.dtype]: | ||
return torch.uint32 | ||
return torch.int32 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note @tlrmchlsmth @varun-sundar-rabindranath - this appears to have broken the PPLX backend as the pplx dispatch function expects uint32 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. VLLM_ALL2ALL_BACKEND="pplx" vllm serve Qwen/Qwen3-30B-A3B-FP8 --data-parallel-size 2 --enable-expert-parallel --enforce-eager result: (EngineCore_0 pid=1149186) ERROR 07-12 19:10:48 [core.py:586] RuntimeError: indices must be of type UInt32
(EngineCore_0 pid=1149186) Process EngineCore_0:
(EngineCore_0 pid=1149186) Traceback (most recent call last):
(EngineCore_0 pid=1149186) File "/home/rshaw/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_0 pid=1149186) self.run()
(EngineCore_0 pid=1149186) File "/home/rshaw/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_0 pid=1149186) self._target(*self._args, **self._kwargs)
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 590, in run_engine_core
(EngineCore_0 pid=1149186) raise e
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 575, in run_engine_core
(EngineCore_0 pid=1149186) engine_core = DPEngineCoreProc(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 835, in __init__
(EngineCore_0 pid=1149186) super().__init__(vllm_config, local_client, handshake_address,
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 404, in __init__
(EngineCore_0 pid=1149186) super().__init__(vllm_config, executor_class, log_stats,
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 82, in __init__
(EngineCore_0 pid=1149186) self._initialize_kv_caches(vllm_config)
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
(EngineCore_0 pid=1149186) available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
(EngineCore_0 pid=1149186) output = self.collective_rpc("determine_available_memory")
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
(EngineCore_0 pid=1149186) answer = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/utils/__init__.py", line 2955, in run_method
(EngineCore_0 pid=1149186) return func(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_0 pid=1149186) return func(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/worker/gpu_worker.py", line 219, in determine_available_memory
(EngineCore_0 pid=1149186) self.model_runner.profile_run()
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/worker/gpu_model_runner.py", line 2239, in profile_run
(EngineCore_0 pid=1149186) = self._dummy_run(self.max_num_tokens, is_profile=True)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_0 pid=1149186) return func(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/v1/worker/gpu_model_runner.py", line 2020, in _dummy_run
(EngineCore_0 pid=1149186) outputs = model(
(EngineCore_0 pid=1149186) ^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186) return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186) return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 529, in forward
(EngineCore_0 pid=1149186) hidden_states = self.model(input_ids, positions, intermediate_tensors,
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/compilation/decorators.py", line 173, in __call__
(EngineCore_0 pid=1149186) return self.forward(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 369, in forward
(EngineCore_0 pid=1149186) hidden_states, residual = layer(positions, hidden_states, residual)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186) return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186) return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 313, in forward
(EngineCore_0 pid=1149186) hidden_states = self.mlp(hidden_states)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186) return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186) return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/models/qwen3_moe.py", line 136, in forward
(EngineCore_0 pid=1149186) final_hidden_states = self.experts(hidden_states=hidden_states,
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186) return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186) return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1381, in forward
(EngineCore_0 pid=1149186) return torch.ops.vllm.moe_forward(hidden_states, router_logits,
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(EngineCore_0 pid=1149186) return self._op(*args, **(kwargs or {}))
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1566, in moe_forward
(EngineCore_0 pid=1149186) return self.forward_impl(hidden_states, router_logits)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1465, in forward_impl
(EngineCore_0 pid=1149186) return self.forward_impl_chunked(hidden_states, router_logits)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1454, in forward_impl_chunked
(EngineCore_0 pid=1149186) process_chunk(chunk_start,
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1415, in process_chunk
(EngineCore_0 pid=1149186) final_hidden_states = self.quant_method.apply(
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/quantization/fp8.py", line 970, in apply
(EngineCore_0 pid=1149186) return self.fused_experts(
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_0 pid=1149186) return self._call_impl(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_0 pid=1149186) return forward_call(*args, **kwargs)
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 689, in forward
(EngineCore_0 pid=1149186) _expert_topk_weights) = self.prepare_finalize.prepare(
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py", line 202, in prepare
(EngineCore_0 pid=1149186) self.a2a.dispatch(
(EngineCore_0 pid=1149186) File "/home/rshaw/vllm/tools/ep_kernels/ep_kernels_workspace/pplx-kernels/src/pplx_kernels/all_to_all.py", line 48, in dispatch
(EngineCore_0 pid=1149186) self._dispatch_fn(
(EngineCore_0 pid=1149186) File "/home/rshaw/.vllm/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(EngineCore_0 pid=1149186) return self._op(*args, **(kwargs or {}))
(EngineCore_0 pid=1149186) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=1149186) RuntimeError: indices must be of type UInt32 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @robertgshaw2-redhat that will be fixed by #20825 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay cool - just a note: I was using triton+pplx in this case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the regression! I'll make sure to test it next time. We are also working on improving CI coverage and preparing a script to cover major models for easy local test.
|
||
|
||
def prepare( | ||
self, | ||
|
@@ -100,7 +100,9 @@ def prepare( | |
hidden_dim = a1.size(-1) # K | ||
|
||
assert topk_ids.size(0) == num_tokens | ||
# assert expert_map is None, "NYI" | ||
assert expert_map is None, """with expert map, -1 id is used for | ||
non-local token; this causes error when casting ids to the | ||
topk_indices_dtype() uint32""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added assertion and changed the id type here. test_pplx_moe.py passes. Let me know if we should run a model e2e and how. Thx! @tylertitsworth @bnellnm cc @yeqcharlotte There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at this again, I don't think we need both the assertion and int32 change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @minosfuture Could we lift this assert ? I think
I believe a better solution is to make fyi @tlrmchlsmth There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let me raise a PR for this! thanks for the catch. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be problematic to clear expert_map assigning None to it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The function doesn't use the expert_map - so overriding it to None should be fine as it'd prevent any incorrect use.
I actually dont see it being used in any of the implementations. I think it is okay to remove the argument altogether. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. parameter is needed for this is an inherited function. I commented it all out. This essentially reverts the changes here. As a followup, I think we should fix where it passes an expert_map, and add the assertion back. #20714 @varun-sundar-rabindranath @tlrmchlsmth pls help approve this fix pr and auto-merge it. thx! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I was suggesting we could remove the var in the base-class as I don't see any implementation using it. |
||
|
||
# Is this always going to be a1.device? | ||
device = a1.device | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -883,23 +883,27 @@ def apply( | |
num_expert_group=num_expert_group, | ||
custom_routing_function=custom_routing_function, | ||
scoring_func=scoring_func, | ||
e_score_correction_bias=e_score_correction_bias, | ||
indices_type=self.topk_indices_dtype, | ||
) | ||
e_score_correction_bias=e_score_correction_bias) | ||
|
||
a1_scale = layer.w13_input_scale | ||
a2_scale = layer.w2_input_scale | ||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( | ||
a2_scale.numel() != 1 if a2_scale is not None else False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a fix needed after rebase #19636. cc @bnellnm @luccafong |
||
|
||
return self.fused_experts( | ||
x, | ||
layer.w13_weight, | ||
layer.w2_weight, | ||
topk_weights, | ||
topk_ids, | ||
per_act_token=per_act_token, | ||
activation=activation, | ||
global_num_experts=global_num_experts, | ||
expert_map=None if self.disable_expert_map else expert_map, | ||
w1_scale=layer.w13_weight_scale, | ||
w2_scale=layer.w2_weight_scale, | ||
a1_scale=layer.w13_input_scale, | ||
a2_scale=layer.w2_input_scale, | ||
a1_scale=a1_scale, | ||
a2_scale=a2_scale, | ||
) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
concerned this will break things
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this id a small number for deepep_ll?
Can I run some tests to confirm it's safe?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can also restore type changes for pplx and deepep_ll in this PR, and work on it in a new one.
Hoping to get this PR in to unblock maverick devs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep -- I'm hitting:
RuntimeError: Failed: Assertion error /app/DeepEP/csrc/deep_ep.cpp:1030 'topk_idx.scalar_type() == torch::kInt64'
Let's revert this line, and otherwise lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated! Pls help trigger auto merge. Thanks for reviewing both PRs!