
Commit 7a7abdf

fix perm math
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 4b55733 commit 7a7abdf

23 files changed: +226 additions, -278 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 16 additions & 18 deletions
@@ -50,6 +50,7 @@
     align_module_device,
     delete_offload_parameter,
     get_execution_device,
+    get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -408,16 +409,17 @@ def compress_model(self, model: Module):
                 )
 
                 # remove any existing parameters
-                device = get_execution_device(module)
+                exec_device = get_execution_device(module)
+                offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
-                    delattr(module, name)
+                    delete_offload_parameter(module, name)
 
                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(device)
+                    value = value.to(exec_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param)
+                    register_offload_parameter(module, name, param, offload_device)
 
                 module.quantization_status = QuantizationStatus.COMPRESSED
 
@@ -460,30 +462,26 @@ def decompress_model(self, model: Module):
 
                 # quantization second
                 if prefix in module_to_scheme:
-                    generator = self.quantization_compressor.decompress_from_state_dict(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
+                    state_dict = (
+                        self.quantization_compressor.decompress_module_from_state_dict(
+                            prefix,
+                            state_dict,
+                            scheme=module_to_scheme[prefix],
+                        )
                     )
-                    # generates (mod_path, {param_name, param_val})
-                    # of compressed params and used params, but not unused params
-                    # some used params are removed by get_unexpected_file_keys
-                    state_dict = {
-                        merge_names(module_path, param_name): param_value
-                        for module_path, compressed_data in generator
-                        for param_name, param_value in compressed_data.items()
-                    }
 
                 # remove any existing parameters
-                device = get_execution_device(module)
+                exec_device = get_execution_device(module)
+                offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
                     delete_offload_parameter(module, name)
 
                 # replace with decompressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(device)
+                    value = value.to(exec_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param)
+                    register_offload_parameter(module, name, param, offload_device)
 
                 module.quantization_status = QuantizationStatus.FROZEN
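
Note: the compress/decompress paths above now record both the execution device and the offload device before dropping parameters, so replacements are built on the execution device but stored back on the module's original offload device. A minimal sketch of that pattern, assuming a module managed through compressed_tensors' offloading utilities (the helper name replace_module_params and the placeholder values dict are illustrative, not part of this commit):

import torch
from compressed_tensors.utils import (
    delete_offload_parameter,
    get_execution_device,
    get_offloaded_device,
    register_offload_parameter,
)


def replace_module_params(module: torch.nn.Module, new_values: dict) -> None:
    # capture both devices up front: compute happens on exec_device,
    # persistent storage should end up back on the offload device
    exec_device = get_execution_device(module)
    offload_device = get_offloaded_device(module)

    # drop existing parameters through the offload-aware helper so any
    # offloaded copies are removed along with the module attribute
    for name, _ in list(module.named_parameters()):
        delete_offload_parameter(module, name)

    # re-register replacements, keeping the persistent copy on offload_device
    for name, value in new_values.items():
        param = torch.nn.Parameter(value.to(exec_device), requires_grad=False)
        register_offload_parameter(module, name, param, offload_device)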

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 30 additions & 3 deletions
@@ -24,6 +24,7 @@
     get_nested_weight_mappings,
     merge_names,
 )
+from compressed_tensors.utils.safetensors_load import match_param_name
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -223,9 +224,7 @@ def decompress_from_state_dict(
             state_dict, self.compression_param_names
         )
         for module_path in weight_mappings.keys():
-            weight_data = {}
-            for param_name, param_value in weight_mappings[module_path].items():
-                weight_data[param_name] = param_value
+            weight_data = weight_mappings[module_path].copy()
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[module_path].weights
@@ -234,3 +233,31 @@
                 )
                 weight_data["weight"] = decompressed
             yield module_path, weight_data
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: QuantizationScheme,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Only used by in-memory decompression pathways to decompress the parameters of
+        one module
+
+        :param prefix: prefix of state_dict, typically the path to the module
+        :param state_dict: state dict containing module parameter values
+        :param scheme: quantization scheme of module to decompress
+        :return: state dict with weight decompressed if applicable
+        """
+        state_dict = {
+            key.removeprefix(f"{prefix}."): value for key, value in state_dict.items()
+        }
+
+        if "weight_scale" in state_dict:
+            state_dict["weight"] = self.decompress_weight(
+                compressed_data=state_dict, quantization_args=scheme.weights
+            )
+
+        state_dict = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+
+        return state_dict
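
Note: the new decompress_module_from_state_dict keeps the caller's key layout: it strips the module prefix, decompresses the weight in place when a weight_scale entry is present, then re-attaches the prefix. A self-contained sketch of just that key handling; the tensor shapes and the placeholder "decompressed" result are illustrative, and the decompress_weight call itself is elided because it needs a concrete compressor and real packed data:

import torch

prefix = "model.layers.0.mlp.down_proj"
state_dict = {
    f"{prefix}.weight_packed": torch.zeros(4, 2, dtype=torch.int32),  # placeholder
    f"{prefix}.weight_scale": torch.ones(4, 1),                       # placeholder
}

# strip the module prefix so compression params can be looked up by local name
local = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items()}
assert "weight_scale" in local  # this is what triggers decompress_weight(...)

# ... local["weight"] = self.decompress_weight(compressed_data=local, ...) ...
local["weight"] = torch.zeros(4, 4)  # placeholder for the decompressed weight

# re-attach the prefix before handing the dict back to the caller
restored = {f"{prefix}.{key}": value for key, value in local.items()}
assert f"{prefix}.weight" in restored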

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 1 addition & 10 deletions
@@ -27,14 +27,8 @@
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
-    update_fused_layer_weight_global_scales,
-)
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    QuantizationArgs,
-    QuantizationType,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -272,9 +266,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         )
     )
 
-    if status == QuantizationStatus.INITIALIZED:
-        update_fused_layer_weight_global_scales(model)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 37 additions & 27 deletions
@@ -21,7 +21,6 @@
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -227,31 +226,42 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        # TODO: experiment with vectorizing for loop for performance
-        end = 0
-        for index, group_count in enumerate(group_sizes):
-            sc = scale[:, index].view(-1, 1)
-            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
-
-            start = end
-            end = start + group_count
-            if do_quantize:
-                output[:, start:end] = _quantize(
-                    x=x[:, start:end],
-                    scale=sc,
-                    zero_point=zp,
-                    q_min=q_min,
-                    q_max=q_max,
-                    args=args,
-                    dtype=dtype,
-                    global_scale=global_scale,
-                )
+        x = torch.reshape(
+            x,
+            (
+                x.shape[0],
+                ceil(x.shape[1] / group_size),
+                group_size,
+            ),
+        )
 
-            if do_dequantize:
-                input = output[:, start:end] if do_quantize else x[:, start:end]
-                output[:, start:end] = _dequantize(
-                    x_q=input, scale=sc, zero_point=zp, global_scale=global_scale
-                )
+        if do_quantize:
+            output = _quantize(
+                x=x,
+                scale=scale.unsqueeze(-1),
+                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
+                dtype=dtype,
+                global_scale=global_scale,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+            )
+
+        if do_dequantize:
+            input = output if do_quantize else x
+            output = _dequantize(
+                x_q=input,
+                scale=scale.unsqueeze(-1),
+                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
+                global_scale=global_scale,
+            )
+
+        output = torch.reshape(
+            output,
+            (output.shape[0], output.shape[1] * output.shape[2]),
+        )
+
+        output = output.to(output_dtype)
 
         if not is_column_order:
            output = safe_permute(output, torch.argsort(perm), dim=1)
@@ -394,7 +404,7 @@ def _quantize(
 
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
@@ -427,7 +437,7 @@ def _dequantize(
 
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
     dequant_value = x_q.to(scale.dtype)
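
Note: the vectorized path above replaces the per-group Python loop with a single reshape to (rows, num_groups, group_size), so that per-group scales broadcast over the last axis via unsqueeze(-1) in one _quantize/_dequantize call. A minimal numeric sketch of the same broadcasting idea, using plain torch ops and a symmetric int8 range rather than this module's helpers (both choices are illustrative assumptions):

from math import ceil

import torch

# illustrative shapes: 2 rows, 8 columns, groups of 4 -> 2 groups per row
x = torch.randn(2, 8)
group_size = 4
num_groups = ceil(x.shape[1] / group_size)

# view columns as (rows, num_groups, group_size); scales are per (row, group)
x_grouped = x.reshape(x.shape[0], num_groups, group_size)
scale = x_grouped.abs().amax(dim=-1).clamp(min=1e-8) / 127.0  # (rows, num_groups)

# unsqueeze(-1) broadcasts each group's scale across its group_size columns
q = torch.clamp(torch.round(x_grouped / scale.unsqueeze(-1)), -128, 127)
dq = q * scale.unsqueeze(-1)

# flatten the group axis back to the original column layout
dq = dq.reshape(dq.shape[0], num_groups * group_size)
assert dq.shape == x.shape

The switch from "if global_scale:" to "if global_scale is not None:" in _quantize/_dequantize makes the check a presence test instead of evaluating the tensor's truth value.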

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 7 additions & 113 deletions
@@ -23,26 +23,18 @@
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    generate_global_scale,
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    iter_named_quantizable_modules,
-)
+from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
-    update_parameter_data,
 )
 from torch.nn import Module, Parameter
 
@@ -51,7 +43,6 @@
     "initialize_module_for_quantization",
     "is_attention_module",
     "KVCacheScaleType",
-    "update_fused_layer_weight_global_scales",
 ]
 
 
@@ -162,22 +153,13 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    # 1. Create global_scales for tensor_group
+    # 1. Create global_scales for tensor_group - generates
+    # a per tensor scale
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        # TODO: should move to llmcompressor
-        if base_name == "weight":
-            # When applying weight-only FP4 quantization, generate a global_scale
-            # This scale is applied during runtime to ensure that the generated
-            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-            # which is the expected dtype of NVFP4A16 scales
-            value = generate_global_scale(input_tensor=module.weight)
-            value = value.to(device)
-            init_global_scale = Parameter(value, requires_grad=False)
-        else:
-            init_global_scale = Parameter(
-                torch.empty(1, dtype=torch.float32, device=device),
-                requires_grad=False,
-            )
+        init_global_scale = Parameter(
+            torch.empty(1, dtype=torch.float32, device=device),
+            requires_grad=False,
+        )
         register_offload_parameter(
             module, f"{base_name}_global_scale", init_global_scale
         )
@@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-# TODO: Potentially introduce an argument to turn this off
-# Only relevant for NVFP4A16 currently
-def update_fused_layer_weight_global_scales(model: torch.nn.Module):
-    """
-    When running NVFP4A16 quantization, update the global scale
-    such that q,k,v layers are treated as one tensor with the same
-    global_scale and gate_proj/up_proj layers are treated as one tensor
-    with the same global scale. This is requirement currently being set
-    by vLLM and may be removed in the future OR potentially make it
-    an optional step.
-
-    :param model: model to quantize
-    """
-
-    def _is_attention_module(module: Module):
-        return "attention" in module.__class__.__name__.lower() and (
-            hasattr(module, "k_proj")
-            or hasattr(module, "v_proj")
-            or hasattr(module, "qkv_proj")
-        )
-
-    def _is_mlp_module(module: Module):
-        return "mlp" in module.__class__.__name__.lower() and (
-            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
-        )
-
-    def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
-        """
-        Return True if all the linear layers in the layer_list are
-        NVFP4A16 quantized.
-        """
-        for layer in layer_list:
-            scheme = getattr(layer, "quantization_scheme", None)
-            if scheme is None:
-                return False
-
-            weight_quant_args = scheme.weights
-
-            if weight_quant_args is None:
-                return False
-
-            if not is_fp4(quantization_args=weight_quant_args):
-                return False
-        return True
-
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_attn=True,
-        include_mlp=True,
-    ):
-
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                continue
-
-            if not _valid_fp4_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                continue
-
-            q_weight = submodule.q_proj.weight.data
-            v_weight = submodule.v_proj.weight.data
-            k_weight = submodule.k_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
-            )
-
-            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
-
-        if _is_mlp_module(submodule):
-            if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
-                continue
-
-            gate_data = submodule.gate_proj.weight.data
-            up_data = submodule.up_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((gate_data, up_data), dim=0)
-            )
-
-            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
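
Note: with the fused-layer global-scale logic removed from initialization, the TENSOR_GROUP branch now always registers a single empty fp32 global scale per parameter, presumably to be filled in later (e.g. during calibration) rather than computed from the weight at init time. A small stand-in for that branch; plain register_parameter is used here so the sketch stays self-contained, whereas the real code goes through register_offload_parameter:

import torch
from torch.nn import Parameter

module = torch.nn.Linear(8, 8)
device = module.weight.device
base_name = "weight"

# global scale starts as an uninitialized single fp32 value
init_global_scale = Parameter(
    torch.empty(1, dtype=torch.float32, device=device), requires_grad=False
)
module.register_parameter(f"{base_name}_global_scale", init_global_scale)

assert module.weight_global_scale.shape == (1,)
assert module.weight_global_scale.dtype == torch.float32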
