Commit cc80aea

Authored by cascade812
[Feature] Support sequence parallelism for static fp8 quantization (vllm-project#19181)
Signed-off-by: cascade812 <cascade812@outlook.com>
1 parent ed1fb0f commit cc80aea

File tree: 7 files changed, +534 -198 lines
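
This commit lets the sequence-parallelism pass run together with the RMSNorm + static fp8 quantization fusion pass. Below is a minimal sketch of the pass configuration the new tests exercise, using only the classes imported in the diffs that follow (PassConfig, CompilationConfig, VllmConfig from vllm.config); treat it as illustrative rather than a complete engine setup.

from vllm.config import CompilationConfig, PassConfig, VllmConfig

# Enable sequence parallelism together with RMSNorm + static fp8 quant fusion.
# NoOp elimination is enabled as well, since the fusion pattern needs it.
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
    enable_sequence_parallelism=True,
    enable_fusion=True,
    enable_noop=True))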

tests/compile/test_sequence_parallelism.py

Lines changed: 144 additions & 17 deletions
@@ -6,20 +6,25 @@
 
 import vllm.envs as envs
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fusion import FusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (init_distributed_environment,
                                              initialize_model_parallel)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
 
 from ..utils import multi_gpu_test
 from .backend import TestBackend
 
+FP8_DTYPE = current_platform.fp8_dtype()
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -30,13 +35,16 @@
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16, intermediate_size=32):
+    def __init__(self,
+                 hidden_size=16,
+                 intermediate_size=32,
+                 vllm_config: VllmConfig = None):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.gate_proj = torch.nn.Parameter(
             torch.empty((intermediate_size, hidden_size)))
-        self.norm = RMSNorm(hidden_size, 1e-05)
+        self.norm = RMSNorm(intermediate_size, 1e-05)
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
@@ -79,32 +87,138 @@ def ops_in_model(self):
         return [torch.ops._C.fused_add_rms_norm.default]
 
 
+class TestQuantModel(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size=16,
+                 intermediate_size=32,
+                 vllm_config: VllmConfig = None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.vllm_config = vllm_config
+        self.gate_proj = torch.nn.Parameter(torch.empty(
+            (intermediate_size, hidden_size)),
+                                            requires_grad=False)
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        # Initialize weights
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
+                                      use_per_token_if_dynamic=False)
+
+        self.scale = torch.rand(1, dtype=torch.float32)
+        # Create a weight that is compatible with torch._scaled_mm,
+        # which expects a column-major layout.
+        self.w = torch.rand(hidden_size,
+                            intermediate_size).to(dtype=FP8_DTYPE).t()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        """
+        Forward pass implementing the operations in the FX graph
+
+        Args:
+            hidden_states: Input tensor
+            residual: Residual tensor from previous layer
+
+        Returns:
+            Tuple containing the output tensor
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        #matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # Tensor parallel all-reduce
+        all_reduce = tensor_model_parallel_all_reduce(mm)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(all_reduce, residual)
+
+        # for static input quantization
+        # self.fp8_linear is initialized with use_per_token_if_dynamic=False
+        fp8_linear_result = self.fp8_linear.apply(norm_output,
+                                                  self.w,
+                                                  self.wscale,
+                                                  input_scale=self.scale.to(
+                                                      norm_output.device))
+
+        return fp8_linear_result, residual_output
+
+    def ops_in_model_before(self):
+        ops_to_remove = [torch.ops.vllm.all_reduce.default
+                         ]  # Always removed by SP
+        # The following are only removed if fusion happens
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            ops_to_remove.extend([
+                torch.ops._C.fused_add_rms_norm.default,
+                torch.ops._C.static_scaled_fp8_quant.default,
+            ])
+        return ops_to_remove
+
+    def ops_in_model_after(self):
+        ops_to_add = [
+            torch.ops.vllm.reduce_scatter.default,
+            torch.ops.vllm.all_gather.default
+        ]
+        # The following is only added if fusion happens
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            ops_to_add.append(
+                torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
+        return ops_to_add
+
+    def ops_in_model(self):
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            # If fusion happens, the fused op is the one
+            # we check for (de)functionalization
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+                    ]  # noqa: E501
+        else:
+            # If no fusion, the original ops are checked
+            return [
+                torch.ops._C.fused_add_rms_norm.default,
+                # TODO functionalization pass does not handle this yet
+                # torch.ops._C.static_scaled_fp8_quant.default,
+            ]
+
+
 @multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("enable_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                     reason="Only test on CUDA")
-def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
-                                   hidden_size: int, dtype: torch.dtype):
+def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
+                                   batch_size: int, seq_len: int,
+                                   hidden_size: int, dtype: torch.dtype,
+                                   enable_fusion: bool):
     num_processes = 2
 
     def run_torch_spawn(fn, nprocs):
         # need to use torch.mp.spawn otherwise will have problems with
         # torch.distributed and cuda
         torch.multiprocessing.spawn(fn,
-                                    args=(num_processes, batch_size, seq_len,
-                                          hidden_size, dtype),
+                                    args=(num_processes, test_model_cls,
+                                          batch_size, seq_len, hidden_size,
+                                          dtype, enable_fusion),
                                     nprocs=nprocs)
 
     run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
 
 
-def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
-                                            batch_size: int, seq_len: int,
-                                            hidden_size: int,
-                                            dtype: torch.dtype):
+def sequence_parallelism_pass_on_test_model(
+        local_rank: int, world_size: int,
+        test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
+        hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
     current_platform.seed_everything(0)
 
     device = torch.device(f"cuda:{local_rank}")
@@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     # configure vllm config for SequenceParallelismPass
     vllm_config = VllmConfig()
     vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
-        enable_sequence_parallelism=True))
+        enable_sequence_parallelism=True,
+        enable_fusion=enable_fusion,
+        enable_noop=True))  # NoOp needed for fusion
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
 
     # this is a fake model name to construct the model config
     # in the vllm_config, it's not really used.
-    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
-    vllm_config.model_config = ModelConfig(model=model,
+    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
+    vllm_config.model_config = ModelConfig(model=model_name,
                                            task="auto",
-                                           tokenizer=model,
+                                           tokenizer=model_name,
                                            tokenizer_mode="auto",
                                            trust_remote_code=True,
                                            dtype=dtype,
                                            seed=42)
 
     sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
-    backend_no_func = TestBackend(sequence_parallelism_pass)
+    noop_pass = NoOpEliminationPass(vllm_config)
     func_pass = FixFunctionalizationPass(vllm_config)
-    backend_func = TestBackend(sequence_parallelism_pass, func_pass)
 
-    model = TestModel(hidden_size, hidden_size * 2)
+    passes_for_backend = [noop_pass, sequence_parallelism_pass]
+
+    if enable_fusion:
+        fusion_pass = FusionPass.instance(vllm_config)
+        passes_for_backend.append(fusion_pass)
+
+    backend_no_func = TestBackend(*passes_for_backend)
+    backend_func = TestBackend(*passes_for_backend, func_pass)
+
+    model = test_model_cls(hidden_size,
+                           hidden_size * 2,
+                           vllm_config=vllm_config)
+
     hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                 dtype=dtype)
     residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
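
A short sketch of how one might run just the new fp8 cases of this test locally; it assumes a vLLM source checkout with at least two CUDA GPUs, and the -k filter relies on pytest's default parametrize ids.

import pytest

# Select only the TestQuantModel parametrizations added by this commit.
# "-x" stops at the first failure; "-q" keeps the output short.
pytest.main([
    "tests/compile/test_sequence_parallelism.py",
    "-k", "TestQuantModel",
    "-x", "-q",
])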

tests/distributed/test_sequence_parallel.py

Lines changed: 52 additions & 56 deletions
@@ -28,7 +28,7 @@
 class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
-    sp_enabled: bool
+    enable_fusion: bool
     eager_mode: bool
     chunked_prefill: bool
 
@@ -67,49 +67,18 @@ def detailed(
         task: TaskOption = "auto",
         load_format: Optional[str] = None,
     ):
+        parallel_setups = []
+        for eager_mode_val in [False, True]:
+            for pp_multiplier in [1, 2]:
+                for chunked_prefill_val in [False, True]:
+                    parallel_setups.append(
+                        ParallelSetup(tp_size=tp_base,
+                                      pp_size=pp_multiplier * pp_base,
+                                      enable_fusion=False,
+                                      eager_mode=eager_mode_val,
+                                      chunked_prefill=chunked_prefill_val))
         return SPTestSettings(
-            parallel_setups=[
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=pp_base,
-                              sp_enabled=True,
-                              eager_mode=False,
-                              chunked_prefill=False),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=pp_base,
-                              sp_enabled=True,
-                              eager_mode=False,
-                              chunked_prefill=True),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=pp_base,
-                              sp_enabled=True,
-                              eager_mode=True,
-                              chunked_prefill=False),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=pp_base,
-                              sp_enabled=True,
-                              eager_mode=True,
-                              chunked_prefill=True),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=2 * pp_base,
-                              sp_enabled=True,
-                              eager_mode=False,
-                              chunked_prefill=False),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=2 * pp_base,
-                              sp_enabled=True,
-                              eager_mode=False,
-                              chunked_prefill=True),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=2 * pp_base,
-                              sp_enabled=True,
-                              eager_mode=True,
-                              chunked_prefill=False),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=2 * pp_base,
-                              sp_enabled=True,
-                              eager_mode=True,
-                              chunked_prefill=True)
-            ],
+            parallel_setups=parallel_setups,
             distributed_backends=["mp", "ray"],
             vllm_major_versions=["1", "1"],
             task=task,
@@ -126,19 +95,44 @@ def fast(
         multi_node_only: bool = False,
         load_format: Optional[str] = None,
     ):
+        parallel_setups = []
+        for eager_mode_val in [False, True]:
+            for pp_multiplier in [1, 2]:
+                for chunked_prefill_val in [False, True]:
+                    parallel_setups.append(
+                        ParallelSetup(tp_size=tp_base,
+                                      pp_size=pp_multiplier * pp_base,
+                                      enable_fusion=False,
+                                      eager_mode=eager_mode_val,
+                                      chunked_prefill=chunked_prefill_val))
         return SPTestSettings(
-            parallel_setups=[
+            parallel_setups=parallel_setups,
+            distributed_backends=["mp", "ray"],
+            vllm_major_versions=["1", "1"],
+            task=task,
+            test_options=SPTestOptions(multi_node_only=multi_node_only,
+                                       load_format=load_format),
+        )
+
+    @staticmethod
+    def fp8_quant(
+        *,
+        tp_base: int = 2,
+        pp_base: int = 1,
+        task: TaskOption = "auto",
+        multi_node_only: bool = False,
+        load_format: Optional[str] = None,
+    ):
+        parallel_setups = []
+        for fusion_val in [False, True]:
+            parallel_setups.append(
                 ParallelSetup(tp_size=tp_base,
                               pp_size=pp_base,
-                              sp_enabled=True,
-                              eager_mode=False,
-                              chunked_prefill=False),
-                ParallelSetup(tp_size=tp_base,
-                              pp_size=2 * pp_base,
-                              sp_enabled=True,
-                              eager_mode=False,
-                              chunked_prefill=False),
-            ],
+                              enable_fusion=fusion_val,
+                              eager_mode=True,
+                              chunked_prefill=False))
+        return SPTestSettings(
+            parallel_setups=parallel_setups,
             distributed_backends=["mp", "ray"],
             vllm_major_versions=["1", "1"],
             task=task,
@@ -171,7 +165,7 @@ def _compare_sp(
     (
         tp_size,
         pp_size,
-        sp_enabled,
+        enable_fusion,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -240,9 +234,9 @@ def _compare_sp(
         'compile_sizes': [4, 8],
         'splitting_ops': [],
         'pass_config': {
-            'enable_sequence_parallelism': sp_enabled,
+            'enable_sequence_parallelism': True,
+            'enable_fusion': enable_fusion,
             'enable_noop': True,
-            'enable_fusion': True,
         },
     }
 
@@ -291,12 +285,14 @@ def _compare_sp(
 SP_TEXT_GENERATION_MODELS = {
     # [Decoder-only]
     "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
+    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
 }
 
 SP_TEST_MODELS = [
     # TODO support other models
     # [LANGUAGE GENERATION]
     "meta-llama/Llama-3.2-1B-Instruct",
+    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
 ]
 
 
0 commit comments
