import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fusion import FusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                         PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend

+FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
    "Hello, my name is",
    "The president of the United States is",
@@ -30,13 +35,16 @@

class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size=16, intermediate_size=32):
+    def __init__(self,
+                 hidden_size=16,
+                 intermediate_size=32,
+                 vllm_config: VllmConfig = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size)))
-        self.norm = RMSNorm(hidden_size, 1e-05)
+        self.norm = RMSNorm(intermediate_size, 1e-05)
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

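The RMSNorm width change above (hidden_size to intermediate_size) matches the activation the norm actually sees: the gate projection maps hidden_size features to intermediate_size features before normalization, as the forward passes below make explicit. A standalone shape check with the test's default sizes (illustrative only, not part of the diff):

import torch

tokens, hidden_size, intermediate_size = 4, 16, 32
x = torch.randn(tokens, hidden_size)
# Same layout as self.gate_proj: (intermediate_size, hidden_size)
gate_proj = torch.randn(intermediate_size, hidden_size)
# Mirrors forward(): view @ gate_proj.permute(1, 0)
y = x @ gate_proj.permute(1, 0)
print(y.shape)  # torch.Size([4, 32]) -> the norm operates on intermediate_size features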
@@ -79,32 +87,138 @@ def ops_in_model(self):
        return [torch.ops._C.fused_add_rms_norm.default]


+class TestQuantModel(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size=16,
+                 intermediate_size=32,
+                 vllm_config: VllmConfig = None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.vllm_config = vllm_config
+        self.gate_proj = torch.nn.Parameter(torch.empty(
+            (intermediate_size, hidden_size)),
+                                            requires_grad=False)
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        # Initialize weights
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
+                                      use_per_token_if_dynamic=False)
+
+        self.scale = torch.rand(1, dtype=torch.float32)
+        # Create a weight that is compatible with torch._scaled_mm,
+        # which expects a column-major layout.
+        self.w = torch.rand(hidden_size,
+                            intermediate_size).to(dtype=FP8_DTYPE).t()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        """
+        Forward pass implementing the operations in the FX graph
+
+        Args:
+            hidden_states: Input tensor
+            residual: Residual tensor from previous layer
+
+        Returns:
+            Tuple of (quantized linear output, residual output)
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # Matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # Tensor parallel all-reduce
+        all_reduce = tensor_model_parallel_all_reduce(mm)
+
+        # Layer normalization
+        norm_output, residual_output = self.norm(all_reduce, residual)
+
+        # for static input quantization
+        # self.fp8_linear is initialized with use_per_token_if_dynamic=False
+        fp8_linear_result = self.fp8_linear.apply(norm_output,
+                                                  self.w,
+                                                  self.wscale,
+                                                  input_scale=self.scale.to(
+                                                      norm_output.device))
+
+        return fp8_linear_result, residual_output
+
+    def ops_in_model_before(self):
+        ops_to_remove = [torch.ops.vllm.all_reduce.default
+                         ]  # Always removed by SP
+        # The following are only removed if fusion happens
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            ops_to_remove.extend([
+                torch.ops._C.fused_add_rms_norm.default,
+                torch.ops._C.static_scaled_fp8_quant.default,
+            ])
+        return ops_to_remove
+
+    def ops_in_model_after(self):
+        ops_to_add = [
+            torch.ops.vllm.reduce_scatter.default,
+            torch.ops.vllm.all_gather.default
+        ]
+        # The following is only added if fusion happens
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            ops_to_add.append(
+                torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
+        return ops_to_add
+
+    def ops_in_model(self):
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            # If fusion happens, the fused op is the one
+            # we check for (de)functionalization
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+                    ]  # noqa: E501
+        else:
+            # If no fusion, the original ops are checked
+            return [
+                torch.ops._C.fused_add_rms_norm.default,
+                # TODO functionalization pass does not handle this yet
+                # torch.ops._C.static_scaled_fp8_quant.default,
+            ]
+
+
@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
-def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
-                                   hidden_size: int, dtype: torch.dtype):
+def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
+                                   batch_size: int, seq_len: int,
+                                   hidden_size: int, dtype: torch.dtype,
+                                   enable_fusion: bool):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(fn,
-                                    args=(num_processes, batch_size, seq_len,
-                                          hidden_size, dtype),
+                                    args=(num_processes, test_model_cls,
+                                          batch_size, seq_len, hidden_size,
+                                          dtype, enable_fusion),
                                    nprocs=nprocs)

    run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)


-def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
-                                            batch_size: int, seq_len: int,
-                                            hidden_size: int,
-                                            dtype: torch.dtype):
+def sequence_parallelism_pass_on_test_model(
+        local_rank: int, world_size: int,
+        test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
+        hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
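The unchanged lines elided between this hunk and the next bind each spawned worker to its GPU and bring up the distributed state. A minimal sketch of that setup inside the worker, assuming a single-node run with a placeholder master address and port (the values below are illustrative; update_environment_variables, init_distributed_environment and initialize_model_parallel come from the imports at the top of the file):

    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Placeholder rendezvous settings for a single-node, two-GPU run.
    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # Bring up torch.distributed and the tensor-parallel group.
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)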
@@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
    # configure vllm config for SequenceParallelismPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
-        enable_sequence_parallelism=True))
+        enable_sequence_parallelism=True,
+        enable_fusion=enable_fusion,
+        enable_noop=True))  # NoOp needed for fusion
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
-    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
-    vllm_config.model_config = ModelConfig(model=model,
+    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
+    vllm_config.model_config = ModelConfig(model=model_name,
                                           task="auto",
-                                           tokenizer=model,
+                                           tokenizer=model_name,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
-    backend_no_func = TestBackend(sequence_parallelism_pass)
+    noop_pass = NoOpEliminationPass(vllm_config)
    func_pass = FixFunctionalizationPass(vllm_config)
-    backend_func = TestBackend(sequence_parallelism_pass, func_pass)

-    model = TestModel(hidden_size, hidden_size * 2)
+    passes_for_backend = [noop_pass, sequence_parallelism_pass]
+
+    if enable_fusion:
+        fusion_pass = FusionPass.instance(vllm_config)
+        passes_for_backend.append(fusion_pass)
+
+    backend_no_func = TestBackend(*passes_for_backend)
+    backend_func = TestBackend(*passes_for_backend, func_pass)
+
+    model = test_model_cls(hidden_size,
+                           hidden_size * 2,
+                           vllm_config=vllm_config)
+
    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype)
    residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
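The hunk stops before the assertions. For context, a test of this shape typically compiles the model once per backend, runs it so the passes fire, and then asserts on the traced ops; a sketch under the assumption that the TestBackend utility exposes check_before_ops / check_after_ops helpers (not shown in this excerpt):

    # Run the model once through each backend so the passes are applied.
    torch.compile(model, backend=backend_no_func)(hidden_states, residual)
    torch.compile(model, backend=backend_func)(hidden_states, residual)

    # all_reduce should appear in the graph before the passes and be replaced
    # by reduce_scatter + all_gather (plus the fused RMSNorm+quant op when
    # fusion is enabled) afterwards.
    backend_no_func.check_before_ops(model.ops_in_model_before())
    backend_no_func.check_after_ops(model.ops_in_model_after())

    # backend_func additionally runs FixFunctionalizationPass; the ops listed
    # by ops_in_model() can then be located in its final graph, e.g. with
    # find_auto_fn / find_auto_fn_maybe from vllm.compilation.fx_utils.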