Skip to content

Commit 37a7d5d

Browse files
ilmarkov
and authored
[Misc] Refactor AllReduceFusionPass. Remove parameter (#20918)
Signed-off-by: ilmarkov <imarkov@redhat.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
1 parent d4d3094 commit 37a7d5d

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

tests/compile/test_fusion_all_reduce.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
132132
dtype=dtype,
133133
seed=42)
134134

135-
all_reduce_fusion_pass = AllReduceFusionPass(
136-
vllm_config, vllm_config.compilation_config.pass_config.
137-
fi_allreduce_fusion_max_token_num)
135+
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
138136
backend = TestBackend(all_reduce_fusion_pass)
139137

140138
model = test_model_cls(hidden_size)

vllm/compilation/collective_fusion.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def replacement(residual: torch.Tensor, input: torch.Tensor,
397397

398398
class AllReduceFusionPass(VllmInductorPass):
399399

400-
def __init__(self, config: VllmConfig, max_token_num: int):
400+
def __init__(self, config: VllmConfig):
401401
super().__init__(config)
402402
self.disabled = True
403403
self.tp_size = get_tensor_model_parallel_world_size()
@@ -429,7 +429,8 @@ def __init__(self, config: VllmConfig, max_token_num: int):
429429
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
430430
tp_rank=rank,
431431
tp_size=self.tp_size,
432-
max_token_num=max_token_num,
432+
max_token_num=config.compilation_config.pass_config.
433+
fi_allreduce_fusion_max_token_num,
433434
hidden_dim=self.hidden_dim,
434435
group=self.group,
435436
use_fp32_lamport=use_fp32_lamport,
@@ -441,7 +442,8 @@ def __init__(self, config: VllmConfig, max_token_num: int):
441442
rank=rank,
442443
world_size=self.tp_size,
443444
use_fp32_lamport=use_fp32_lamport,
444-
max_token_num=max_token_num,
445+
max_token_num=config.compilation_config.pass_config.
446+
fi_allreduce_fusion_max_token_num,
445447
)
446448

447449
for epsilon in [1e-5, 1e-6]:

vllm/compilation/pass_manager.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,7 @@ def configure(self, config: VllmConfig):
6363
if self.pass_config.enable_attn_fusion:
6464
self.passes += [AttnFusionPass(config)]
6565
if self.pass_config.enable_fi_allreduce_fusion:
66-
self.passes += [
67-
AllReduceFusionPass(
68-
config, self.pass_config.fi_allreduce_fusion_max_token_num)
69-
]
66+
self.passes += [AllReduceFusionPass(config)]
7067
self.fix_functionalization = FixFunctionalizationPass(config)
7168

7269
def add(self, pass_: InductorPass):

0 commit comments

Comments (0)