     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    generate_global_scale,
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    iter_named_quantizable_modules,
-)
+from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
-    update_parameter_data,
 )
 from torch.nn import Module, Parameter
 
@@ -51,7 +43,6 @@
     "initialize_module_for_quantization",
     "is_attention_module",
     "KVCacheScaleType",
-    "update_fused_layer_weight_global_scales",
 ]
 
 
@@ -162,22 +153,13 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    # 1. Create global_scales for tensor_group
+    # 1. Create global_scales for tensor_group - generates
+    # a per tensor scale
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        # TODO: should move to llmcompressor
-        if base_name == "weight":
-            # When applying weight-only FP4 quantization, generate a global_scale
-            # This scale is applied during runtime to ensure that the generated
-            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-            # which is the expected dtype of NVFP4A16 scales
-            value = generate_global_scale(input_tensor=module.weight)
-            value = value.to(device)
-            init_global_scale = Parameter(value, requires_grad=False)
-        else:
-            init_global_scale = Parameter(
-                torch.empty(1, dtype=torch.float32, device=device),
-                requires_grad=False,
-            )
+        init_global_scale = Parameter(
+            torch.empty(1, dtype=torch.float32, device=device),
+            requires_grad=False,
+        )
         register_offload_parameter(
             module, f"{base_name}_global_scale", init_global_scale
         )
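For context on what this hunk removes: the deleted comment explains that the weight global scale exists so the per-group (local) FP4 scales land inside the FP8-E4M3 range expected by NVFP4A16. After this change the parameter is only initialized as an empty placeholder here, and the value is expected to come from downstream tooling (per the removed TODO, llmcompressor). A minimal sketch of how such a per-tensor global scale is commonly computed, assuming the usual NVFP4 convention of FP8_max * FP4_max / amax; the helper name `compute_global_scale` is illustrative and not part of compressed-tensors:

```python
import torch

# Illustrative dtype maxima following the NVFP4 convention
# (FP8-E4M3 max = 448.0, FP4-E2M1 max = 6.0). compressed-tensors exposes
# these via FP8_E4M3_DATA / FP4_E2M1_DATA; the exact generation logic now
# lives outside this file.
FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0


def compute_global_scale(weight: torch.Tensor) -> torch.Tensor:
    # Choose a scale so that the per-group local scales derived from
    # (global_scale * weight) fit within the FP8-E4M3 representable range.
    amax = weight.abs().max().to(torch.float32)
    return ((FP8_E4M3_MAX * FP4_E2M1_MAX) / amax).reshape(1)
```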
@@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-# TODO: Potentially introduce an argument to turn this off
-# Only relevant for NVFP4A16 currently
-def update_fused_layer_weight_global_scales(model: torch.nn.Module):
-    """
-    When running NVFP4A16 quantization, update the global scale
-    such that q,k,v layers are treated as one tensor with the same
-    global_scale and gate_proj/up_proj layers are treated as one tensor
-    with the same global scale. This is requirement currently being set
-    by vLLM and may be removed in the future OR potentially make it
-    an optional step.
-
-    :param model: model to quantize
-    """
-
-    def _is_attention_module(module: Module):
-        return "attention" in module.__class__.__name__.lower() and (
-            hasattr(module, "k_proj")
-            or hasattr(module, "v_proj")
-            or hasattr(module, "qkv_proj")
-        )
-
-    def _is_mlp_module(module: Module):
-        return "mlp" in module.__class__.__name__.lower() and (
-            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
-        )
-
-    def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
-        """
-        Return True if all the linear layers in the layer_list are
-        NVFP4A16 quantized.
-        """
-        for layer in layer_list:
-            scheme = getattr(layer, "quantization_scheme", None)
-            if scheme is None:
-                return False
-
-            weight_quant_args = scheme.weights
-
-            if weight_quant_args is None:
-                return False
-
-            if not is_fp4(quantization_args=weight_quant_args):
-                return False
-        return True
-
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_attn=True,
-        include_mlp=True,
-    ):
-
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                continue
-
-            if not _valid_fp4_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                continue
-
-            q_weight = submodule.q_proj.weight.data
-            v_weight = submodule.v_proj.weight.data
-            k_weight = submodule.k_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
-            )
-
-            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
-
-        if _is_mlp_module(submodule):
-            if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
-                continue
-
-            gate_data = submodule.gate_proj.weight.data
-            up_data = submodule.up_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((gate_data, up_data), dim=0)
-            )
-
-            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
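With `update_fused_layer_weight_global_scales` removed from this file, the removed docstring's requirement still stands for vLLM serving: fused q/k/v and gate/up projections should share one weight_global_scale. A rough sketch of the caller-side step that would now live downstream (e.g. in llmcompressor), reusing `update_parameter_data` from compressed_tensors.utils and the illustrative `compute_global_scale` above; the function name and the q_proj/k_proj/v_proj module layout are assumptions mirroring the removed helper, not a published API:

```python
import torch
from compressed_tensors.utils import update_parameter_data


def share_qkv_global_scale(attn_module: torch.nn.Module) -> None:
    # Treat q/k/v as one tensor: compute a single global scale over the
    # concatenated weights, then write it back to each projection.
    fused = torch.cat(
        (
            attn_module.q_proj.weight.data,
            attn_module.k_proj.weight.data,
            attn_module.v_proj.weight.data,
        ),
        dim=0,
    )
    value = compute_global_scale(fused)  # illustrative helper sketched above
    for proj in (attn_module.q_proj, attn_module.k_proj, attn_module.v_proj):
        update_parameter_data(proj, value, "weight_global_scale")
```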