Commit 36204f0

update quant/dequant steps; update scale calculation step
1 parent 79437ef commit 36204f0

4 files changed (+57, -19 lines)

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 41 additions & 14 deletions
@@ -50,6 +50,7 @@ def quantize(
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -76,6 +77,7 @@ def quantize(
         do_quantize=True,
         do_dequantize=False,
         g_idx=g_idx,
+        global_scale=global_scale,
     )

@@ -87,6 +89,7 @@ def dequantize(
     args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -129,6 +132,7 @@ def dequantize(
         do_dequantize=True,
         dtype=dtype,
         g_idx=g_idx,
+        global_scale=global_scale,
     )

@@ -139,6 +143,7 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     g_idx: Optional[torch.Tensor] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -162,6 +167,7 @@ def fake_quantize(
         do_quantize=True,
         do_dequantize=True,
         g_idx=g_idx,
+        global_scale=global_scale,
     )

@@ -175,6 +181,7 @@ def _process_quantization(
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
@@ -222,35 +229,44 @@ def _process_quantization(
             end = start + group_count
             if do_quantize:
                 output[:, start:end] = _quantize(
-                    x[:, start:end],
-                    sc,
-                    zp,
-                    q_min,
-                    q_max,
-                    args,
+                    x=x[:, start:end],
+                    scale=sc,
+                    zero_point=zp,
+                    q_min=q_min,
+                    q_max=q_max,
+                    args=args,
                     dtype=dtype,
+                    global_scale=global_scale,
                 )

             if do_dequantize:
                 input = output[:, start:end] if do_quantize else x[:, start:end]
-                output[:, start:end] = _dequantize(input, sc, zp)
+                output[:, start:end] = _dequantize(
+                    x=input, scale=sc, zero_point=zp, global_scale=global_scale
+                )

         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)

     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
-                x,
-                scale,
-                zero_point,
-                q_min,
-                q_max,
-                args,
+                x=x,
+                scale=scale,
+                zero_point=zero_point,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
                 dtype=dtype,
+                global_scale=global_scale,
             )
         if do_dequantize:
-            output = _dequantize(output if do_quantize else x, scale, zero_point)
+            output = _dequantize(
+                output if do_quantize else x,
+                scale=scale,
+                zero_point=zero_point,
+                global_scale=global_scale,
+            )

     return output

@@ -331,6 +347,7 @@ def forward_quantize(
         return value

     g_idx = getattr(module, "weight_g_idx", None)
+    global_scale = getattr(module, f"{base_name}_global_scale", None)

     if args.dynamic:
         # dynamic quantization - determine the scale/zp on the fly
@@ -346,6 +363,7 @@ def forward_quantize(
         zero_point=zero_point,
         args=args,
         g_idx=g_idx,
+        global_scale=global_scale,
     )

@@ -358,11 +376,16 @@ def _quantize(
     q_max: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

+    if global_scale:
+        scale = scale.to(global_scale.dtype) * global_scale
+
     scaled = x / scale
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
+
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -382,8 +405,12 @@ def _dequantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
+    global_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

+    if global_scale:
+        scale = scale.to(global_scale.dtype) * global_scale
+
     dequant_value = x_q.to(scale.dtype)

     if zero_point is not None:
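For orientation, a minimal standalone sketch (not part of the commit) of what the new global_scale argument does in _quantize/_dequantize above: the local scale is first rescaled by the global scale, and the same effective scale is then applied on both the quantize and the dequantize path. The helper name and sample values below are illustrative, and the sketch checks for None rather than relying on tensor truthiness.

import torch

def sketch_fake_quantize(x, scale, global_scale=None, q_min=-6.0, q_max=6.0):
    # fold the (scalar) global scale into the local/group scale, as in _quantize
    if global_scale is not None:
        scale = scale.to(global_scale.dtype) * global_scale
    x_q = torch.clamp(x / scale, q_min, q_max)  # quantize (rounding omitted)
    return x_q * scale                          # dequantize with the same scale

x = torch.randn(4, 16)
scale = torch.full((4, 1), 0.05)
global_scale = torch.tensor(2.0)
out = sketch_fake_quantize(x, scale, global_scale)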

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 3 additions & 0 deletions
@@ -167,6 +167,8 @@ def _initialize_scale_zero_point(

     # NVFP4 support; use FP8 scales
     # For weight quant, attach global scales for NVFP4
+    # TODO: How do we know if we need a global scale?
+    # TODO: NVFP4 Scheme
     if (
         base_name == "weight"
         and quantization_args.num_bits == 4
@@ -177,6 +179,7 @@ def _initialize_scale_zero_point(
         tensor_amax = torch.abs(module.weight.data).max().to(torch.float32)
         # Setting data for now - could possibly be handled later in the pipeline
         value = FP8_E4M3_DATA.max * FP4_NVFP4_DATA.max / tensor_amax
+        # use the weight dtype (bfloat) maybe use float32 to start?
         value = value.to(torch.float32).to(device)
         # Assuming the global scale can be torch.float16/bfloat16/module weight dtype and not only torch.float32?
         init_global_scale = Parameter(value, requires_grad=False)
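Taken in isolation, the initialization above reduces to the arithmetic below: one per-tensor global scale derived from the weight's absolute maximum, sized so that the product of the global scale and an FP4 group scale can cover the weight range. A hedged sketch, assuming FP8 E4M3 (max 448) for the scale dtype and FP4 E2M1 (max 6.0) for the values; variable names are illustrative.

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FP4_E2M1_MAX = 6.0                                    # largest E2M1 (FP4) magnitude

weight = torch.randn(128, 256, dtype=torch.bfloat16)
tensor_amax = weight.abs().max().to(torch.float32)

# per-tensor global scale, kept in float32 as in the diff above
global_scale = FP8_E4M3_MAX * FP4_E2M1_MAX / tensor_amax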

src/compressed_tensors/quantization/quant_args.py

Lines changed: 3 additions & 0 deletions
@@ -56,6 +56,7 @@ class FloatArgs:
     min=torch.finfo(torch.float8_e4m3fn).min,
     dtype=torch.float8_e4m3fn,
 )
+# Don't call NVFP4; should be based on exponent and mantissa
 FP4_NVFP4_DATA = FloatArgs(exponent=2, mantissa=1, bits=4, max=6.0, min=-6.0)

@@ -265,6 +266,7 @@ def pytorch_dtype(self) -> torch.dtype:
                 return FP8_E4M3_DATA.dtype
             else:
                 assert self.num_bits == 4
+                # TODO: Use the look-up?
                 # TODO: will return None for now until updated in FloatArgs
                 return FP4_NVFP4_DATA.dtype
         elif self.type == QuantizationType.INT:
@@ -299,6 +301,7 @@ def round_to_quantized_type(
             rounded = tensor.to(FP8_E4M3_DATA.dtype)
         else:
             assert args.num_bits == 4
+            # TODO: Use the FP4_NVFP4_DATA class to use a look-up table
            # TODO: cast to whatever value we want fp4 to be post quantization/clamping
             rounded = tensor.to(FP4_NVFP4_DATA.dtype)
     elif args.type == QuantizationType.INT:
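The two look-up TODOs above exist because torch has no native 4-bit float dtype. One possible shape for such a look-up (an assumption for illustration, not code from this commit) is rounding each element to the nearest representable E2M1 value, i.e. {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}:

import torch

# representable E2M1 (FP4) magnitudes, positive and negative (hypothetical helper)
_FP4_POS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_E2M1_VALUES = torch.cat([_FP4_POS, -_FP4_POS[1:]])

def round_to_fp4(tensor: torch.Tensor) -> torch.Tensor:
    # nearest-value look-up; the result keeps the input dtype
    values = FP4_E2M1_VALUES.to(tensor.dtype)
    diffs = (tensor.unsqueeze(-1) - values).abs()
    return values[diffs.argmin(dim=-1)]

print(round_to_fp4(torch.tensor([0.3, 2.4, -4.7, 7.0])))  # rounds to 0.5, 2.0, -4.0, 6.0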

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 10 additions & 5 deletions
@@ -55,7 +55,10 @@


 def calculate_qparams(
-    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+    min_vals: Tensor,
+    max_vals: Tensor,
+    quantization_args: QuantizationArgs,
+    global_scale: Optional[Tensor] = None,
 ) -> Tuple[FloatTensor, IntTensor]:
     """
     :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
@@ -81,17 +84,19 @@ def calculate_qparams(
             quantization_args.num_bits == 4
             and quantization_args.type == QuantizationType.FLOAT
         ):
+            assert global_scale is not None
             # TODO: how do we pass in the global scale?
             # An observer is attached per module, we can conditionally pass in
-            # the global scale
-            scale = global_scale * (max_val_pos / FP4_NVFP4_DATA.max)
-            scale = scale.to(FP8_E4M3_DATA.dtype).to(torch.float32)
+            # the global scale --> TODO: check for presence of the global when updating the scale
+            # TODO: maybe remove FP8 scale cast
+            scale = max_val_pos / FP4_NVFP4_DATA.max
             scale = scale / global_scale
+            scale = scale.to(FP8_E4M3_DATA.dtype)  # .to(torch.float32)
+
         else:
             # Divide over bit range over max value?
             scales = max_val_pos / (float(bit_range) / 2)

-        # needed for fp4?
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
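Putting the FP4 branch above together: the per-group scale is expressed relative to the global scale and then stored in FP8. A minimal sketch of that calculation, assuming a scalar float32 global_scale; the function name and values are illustrative, not part of the repo.

import torch

FP4_E2M1_MAX = 6.0
FP8_SCALE_DTYPE = torch.float8_e4m3fn

def sketch_fp4_scale(max_val_pos: torch.Tensor, global_scale: torch.Tensor) -> torch.Tensor:
    scale = max_val_pos / FP4_E2M1_MAX  # maps each group's max onto FP4's max value
    scale = scale / global_scale        # factor out the global scale (re-applied in _quantize)
    return scale.to(FP8_SCALE_DTYPE)    # FP8 storage, as in the diff above

max_val_pos = torch.tensor([0.12, 0.48, 0.03])
global_scale = torch.tensor(1.5)
print(sketch_fp4_scale(max_val_pos, global_scale))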
