Commit fc0f41d

ilmarkov authored
Integration SM100 FlashInfer fused allreduce RMSNorm (#20691)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
1 parent 7b828e3 commit fc0f41d
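
For context, the fusion pass exercised here rewrites a tensor-parallel all-reduce that feeds an RMSNorm into a single FlashInfer kernel: the test models below expect torch.ops.vllm.all_reduce before the pass runs and torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm after it. Below is a minimal sketch of the unfused computation being matched, with plain torch.distributed standing in for vLLM's tensor_model_parallel_all_reduce and the name ref_allreduce_rms_norm being purely illustrative:

import torch
import torch.distributed as dist

def ref_allreduce_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                           eps: float = 1e-6) -> torch.Tensor:
    # Unfused pattern: sum partial results across tensor-parallel ranks,
    # then apply RMSNorm over the last dimension.
    dist.all_reduce(x)
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

The fused kernel avoids materializing the all-reduce output in global memory before normalizing it.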

File tree

4 files changed: +514 -6 lines changed

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                         ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend


class TestAllReduceRMSNormModel(torch.nn.Module):

    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm = self.norm(all_reduce)
        return norm

    def ops_in_model_before(self):
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):

    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm, _ = self.norm(all_reduce, residual)
        return norm

    def ops_in_model_before(self):
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


# Requires two CUDA GPUs with compute capability 10.0 (SM100) and an
# installed flashinfer package; skipped in all other environments.
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model",
    [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
@pytest.mark.skipif(not find_spec("flashinfer"),
                    reason="flashinfer is not installed")
@pytest.mark.skipif(not current_platform.is_device_capability(100),
                    reason="Only test on SM100")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
                                        batch_size: int, seq_len: int,
                                        hidden_size: int, dtype: torch.dtype):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, test_model,
                                          batch_size, seq_len, hidden_size,
                                          dtype),
                                    nprocs=nprocs)

    run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)


def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                         test_model_cls: torch.nn.Module,
                                         batch_size: int, seq_len: int,
                                         hidden_size: int, dtype: torch.dtype):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Minimal single-node process-group environment for torch.distributed.
    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Piecewise compilation with the FlashInfer allreduce fusion pass enabled.
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
                                             custom_ops=["+rms_norm"],
                                             compile_sizes=[2, 4, 8]))
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_fi_allreduce_fusion=True)
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # This is a fake model name to construct the model config
    # in the vllm_config; it's not really used.
    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model_name,
                                           task="auto",
                                           tokenizer=model_name,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    all_reduce_fusion_pass = AllReduceFusionPass(
        vllm_config, vllm_config.compilation_config.pass_config.
        fi_allreduce_fusion_max_token_num)
    backend = TestBackend(all_reduce_fusion_pass)

    model = test_model_cls(hidden_size)

    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                requires_grad=False)
    residual = torch.randn((batch_size * seq_len, hidden_size),
                           requires_grad=False)

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states, residual)

    # The all_reduce op should be (at least partially) rewritten into the
    # fused FlashInfer op.
    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
    backend.check_after_ops(model.ops_in_model_after())
    del all_reduce_fusion_pass
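
TestAllReduceFusedAddRMSNormModel exercises RMSNorm's residual path. Assuming vLLM's usual fused-add semantics, norm, _ = self.norm(all_reduce, residual) computes roughly the following (a sketch only; ref_fused_add_rms_norm is an illustrative name, not vLLM API):

import torch

def ref_fused_add_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float = 1e-6):
    # The residual is added before normalization; both the normalized
    # output and the updated residual are returned (the test keeps only
    # the former).
    residual = residual + x
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    norm = (residual.float() * torch.rsqrt(variance + eps)).to(x.dtype)
    return norm * weight, residual

For both model variants, check_after_ops asserts that the fused flashinfer_trtllm_fused_allreduce_norm op appears in the compiled graph, while fully_replaced=False in check_before_ops suggests that unfused all_reduce calls may legitimately remain where the pattern does not apply.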
