
Commit 79437ef
update
1 parent 974953c

3 files changed: +19 -19 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 3 additions & 3 deletions

@@ -374,9 +374,9 @@ def compress(
 
         compressed_state_dict = state_dict
 
-        quantized_modules_to_args: Dict[
-            str, QuantizationArgs
-        ] = map_modules_to_quant_args(model)
+        quantized_modules_to_args: Dict[str, QuantizationArgs] = (
+            map_modules_to_quant_args(model)
+        )
 
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 5 additions & 6 deletions

@@ -22,6 +22,8 @@
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
+    FP4_NVFP4_DATA,
+    FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
@@ -31,8 +33,6 @@
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
-    FP4_NVFP4_DATA,
-    FP8_E4M3_DATA,
     disable_hf_hook,
     has_offloaded_params,
     register_offload_parameter,
@@ -176,11 +176,10 @@ def _initialize_scale_zero_point(
         # create and attach nvfp4 data
         tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
         # Setting data for now - could possibly be handled later in the pipeline
-        values = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax
+        value = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax
+        value = value.to(torch.float32).to(device)
         # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
-        init_global_scale = Parameter(
-            value, dtype=torch.float32, device=device, requires_grad=False
-        )
+        init_global_scale = Parameter(value, requires_grad=False)
         register_offload_parameter(
             module, f"f{base_name}_global_scale", init_global_scale
         )
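The initialize.py hunk fixes how the NVFP4 global scale is constructed: torch.nn.Parameter accepts only (data, requires_grad), so the removed call passing dtype= and device= keyword arguments could not have worked; the tensor is now cast and moved before being wrapped. A minimal standalone sketch of the computation, assuming FP8_E4M3_DATA.max is 448.0 (the float8_e4m3fn maximum) and using a made-up weight tensor in place of module.weight.data:

    import torch
    from torch.nn import Parameter

    FP8_E4M3_MAX = 448.0  # assumed max of FP8_E4M3_DATA (float8_e4m3fn)
    FP4_NVFP4_MAX = 6.0   # max of FP4_NVFP4_DATA per the quant_args.py diff

    weight = torch.randn(64, 64)  # stand-in for module.weight.data
    device = weight.device

    # The global scale maps the observed absolute maximum of the weights
    # into the combined FP8 (per-block scale) x FP4 (value) dynamic range.
    tensor_amax = torch.abs(weight).max().to(torch.float32)
    value = FP8_E4M3_MAX * FP4_NVFP4_MAX / tensor_amax

    # Cast and place the tensor first; Parameter has no dtype/device kwargs.
    value = value.to(torch.float32).to(device)
    global_scale = Parameter(value, requires_grad=False)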

src/compressed_tensors/quantization/quant_args.py

Lines changed: 11 additions & 10 deletions

@@ -34,6 +34,17 @@
     "ActivationOrdering",
 ]
 
+
+@dataclass
+class FloatArgs:
+    exponent: int
+    mantissa: int
+    bits: int
+    max: float
+    min: float
+    dtype: Optional[torch.dtype] = None
+
+
 # TODO: Remove soon in favour of a more descriptive FloatArgs
 FP8_DTYPE = torch.float8_e4m3fn
 
@@ -48,16 +59,6 @@
 FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)
 
 
-@dataclass
-class FloatArgs:
-    exponent: int
-    mantissa: int
-    bits: int
-    max: float
-    min: float
-    dtype: Optional[torch.dtype] = None
-
-
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
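The quant_args.py hunks move the FloatArgs dataclass above the module-level constants that instantiate it. Python executes a module body top to bottom at import time, so FP8_E4M3_DATA = FloatArgs(...) raises NameError while the class is still defined further down the file; this relocation is also why initialize.py now imports the constants from quantization.quant_args instead of compressed_tensors.utils. A self-contained sketch of the ordering issue, using the definitions from the diff:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class FloatArgs:
        exponent: int
        mantissa: int
        bits: int
        max: float
        min: float
        dtype: Optional[torch.dtype] = None


    # Succeeds only because FloatArgs is already defined above; if the class
    # body sat below this line, importing the module would raise NameError.
    FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)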
