
Commit 5eeadc2

[Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral (vllm-project#12303)
Signed-off-by: zhenwei <zhenweiliu@habana.ai>
1 parent 3aee657 commit 5eeadc2

File tree (3 files changed, 57 insertions(+), 2 deletions(-)):

vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/model_loader/loader.py
vllm/worker/hpu_model_runner.py

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 31 additions & 0 deletions

@@ -213,6 +213,34 @@ def forward_cpu(
             e_score_correction_bias,
         )
 
+    def forward_hpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        assert custom_routing_function is None
+        assert layer is not None
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax scoring function is supported for HPU.")
+        if e_score_correction_bias is not None:
+            raise NotImplementedError(
+                "Expert score correction bias is not supported for HPU.")
+        return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
+                                   router_logits, top_k)
+
     def forward_tpu(
         self,
         layer: torch.nn.Module,

@@ -411,6 +439,9 @@ def __init__(
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
+        if current_platform.is_hpu():
+            from vllm_hpu_extension.ops import DynamicFusedMOE
+            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
 
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
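For context, the new forward_hpu path hands the whole MoE computation to the DynamicFusedMOE op instantiated in __init__ above. The snippet below is a minimal, hypothetical stand-in (not the vllm_hpu_extension kernel) that only mirrors the call shape layer.hpu_fused_moe(x, w13_weight, w2_weight, router_logits, top_k), assuming Mixtral-style SwiGLU experts, renormalized softmax top-k routing, and vLLM's usual weight layouts; the class name and loop form are illustrative only.

import torch
import torch.nn.functional as F


class ReferenceDynamicFusedMoE(torch.nn.Module):
    """Hypothetical stand-in for vllm_hpu_extension.ops.DynamicFusedMOE,
    kept only to illustrate the data flow behind forward_hpu."""

    def __init__(self, num_experts: int):
        super().__init__()
        self.num_experts = num_experts

    def forward(self, x, w13_weight, w2_weight, router_logits, top_k):
        # Softmax routing, then pick the top_k experts per token and
        # renormalize their weights (Mixtral-style routing).
        routing = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing, top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        out = torch.zeros_like(x)
        for expert in range(self.num_experts):
            token_idx, slot = torch.where(topk_ids == expert)
            if token_idx.numel() == 0:
                continue
            # w13 fuses the gate and up projections (assumed [E, 2*I, H]);
            # w2 is the down projection (assumed [E, H, I]).
            gate, up = (x[token_idx] @ w13_weight[expert].t()).chunk(2, dim=-1)
            expert_out = (F.silu(gate) * up) @ w2_weight[expert].t()
            scale = topk_weights[token_idx, slot].unsqueeze(-1).to(x.dtype)
            out[token_idx] += scale * expert_out
        return out

The real op fuses these steps for HPU execution; the per-expert loop is only meant to make the routing and the w13/w2 data flow explicit.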

vllm/model_executor/model_loader/loader.py

Lines changed: 10 additions & 0 deletions

@@ -387,6 +387,16 @@ def _xla_weights_iterator(iterator: Generator):
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
 
+        elif current_platform.is_hpu():
+            import habana_frameworks.torch.core as htcore
+
+            def _hpu_weights_iterator(iterator: Generator):
+                for weights in iterator:
+                    yield weights
+                    htcore.mark_step()
+
+            weights_iterator = _hpu_weights_iterator(weights_iterator)
+
         if self.counter_before_loading_weights == 0.0:
             self.counter_before_loading_weights = time.perf_counter()
         # Apply the prefix.
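The HPU branch mirrors the existing XLA wrapper: each weight tensor is yielded unchanged and htcore.mark_step() is called afterwards, presumably to flush the lazily accumulated graph after every weight instead of deferring the entire load. Below is a minimal sketch of the same generator-wrapping pattern with a stubbed _mark_step so it runs without the Habana stack; only the wrapping idea comes from the diff.

from typing import Generator, Iterable, Tuple

import torch


def _mark_step() -> None:
    # Stand-in for habana_frameworks.torch.core.mark_step(), which on a
    # real HPU flushes the lazily accumulated graph to the device.
    pass


def hpu_weights_iterator(
    iterator: Iterable[Tuple[str, torch.Tensor]]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    # Pass every (name, tensor) pair through unchanged, then force a
    # graph break so each weight is materialized before the next one.
    for weights in iterator:
        yield weights
        _mark_step()


# Usage with any (name, tensor) source, e.g. a safetensors iterator:
weights = [("w1", torch.zeros(2, 2)), ("w2", torch.ones(2, 2))]
for name, tensor in hpu_weights_iterator(weights):
    print(name, tuple(tensor.shape))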

vllm/worker/hpu_model_runner.py

Lines changed: 16 additions & 2 deletions

@@ -376,8 +376,22 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
-                                                    num_classes=batch_size)
+        if os.environ.get('VLLM_USE_FAKE_HPU',
+                          '0') == '0' and htorch.utils.internal.is_lazy():
+            block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
+                                                        num_classes=batch_size)
+        else:
+            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
+            # doesn't handle out of bounds classes so we need to convert
+            # all negative values to 0 (block_mapping) or bs (block_groups)
+            block_groups = metadata.block_groups.to(torch.long)
+            block_mapping = torch.nn.functional.relu(block_groups)
+            block_mapping = torch.nn.functional.one_hot(block_mapping,
+                                                        num_classes=batch_size)
+            oob_values = block_groups.lt(0)
+            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
+            block_groups.masked_fill_(oob_values, batch_size)
+            metadata = metadata._replace(block_groups=block_groups)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
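The fallback branch exists because torch.nn.functional.one_hot raises on negative class indices outside HPU lazy mode, while padded entries in block_groups are marked with -1. The following self-contained illustration (toy values of my own, not the runner's real metadata) shows the same relu -> one_hot -> masked_fill_ workaround on a small tensor:

import torch
import torch.nn.functional as F

# Toy block_groups: -1 marks padded blocks that belong to no sequence.
batch_size = 4
block_groups = torch.tensor([0, 2, -1, 3, -1], dtype=torch.long)

# F.one_hot(block_groups, num_classes=batch_size) would raise on the
# negative entries in eager/CPU mode, so clamp them to class 0 first.
clamped = F.relu(block_groups)
block_mapping = F.one_hot(clamped, num_classes=batch_size)

# Zero out the rows that came from out-of-bounds (padding) entries and
# redirect those groups to an out-of-range bucket, as the runner does.
oob = block_groups.lt(0)
block_mapping.masked_fill_(oob.unsqueeze(-1), 0)
block_groups = block_groups.masked_fill(oob, batch_size)

print(block_mapping)
# tensor([[1, 0, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 0, 0],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])
print(block_groups)  # tensor([0, 2, 4, 3, 4])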
