Commit 8a51e1a

Fix bfloat16/float16/float32 options (#1369)
* Fix bfloat16/float16/float32 options

Summary:
There were some problems with the previous implementation of the bfloat16/float16/float32 options: it did not convert the activation to the correct dtype after quantization. This PR fixes that.

Test Plan:

llama:
```
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-fp
```

sam2:
```
server: python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant
client: time xargs -I {} curl -s -w "\n" -X POST http://localhost:8000/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_baseline_shuf_1000
```

Reviewers:

Subscribers:

Tasks:

Tags:

* ruff
1 parent 63d142c
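
For context, the core of the fix is a dtype round trip inside the linear op: the weight subclass keeps its weight in the target dtype, casts the activation (and bias) to that dtype for the matmul, and casts the result back to the activation's original dtype so downstream ops see what they expect. A minimal standalone sketch of that pattern (illustrative only; the function name is made up and this is not the torchao implementation itself):

```
import torch
import torch.nn.functional as F

def linear_in_compute_dtype(act, weight, bias=None, compute_dtype=torch.bfloat16):
    # Run the matmul in the weight's compute dtype, then restore the
    # activation's original dtype (the step the old implementation missed).
    orig_dtype = act.dtype
    out = F.linear(
        act.to(compute_dtype),
        weight.to(compute_dtype),
        bias.to(compute_dtype) if bias is not None else None,
    )
    return out.to(orig_dtype)

x = torch.randn(8, 64)            # fp32 activation
w = torch.randn(32, 64)           # weight held in bf16 for compute
y = linear_in_compute_dtype(x, w)
assert y.dtype == torch.float32   # output dtype matches the incoming activation
```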

File tree: 4 files changed, +266 −58 lines

examples/sam2_amg_server/server.py

Lines changed: 14 additions & 9 deletions
@@ -468,7 +468,7 @@ def load_aot_fast(mask_generator, model_directory):
     pkg = torch._inductor.aoti_load_package(str(path))
     pkg_m = LoadedModel(pkg)
     mask_generator.predictor.model.image_encoder = pkg_m
-
+
     # NOTE: This doesn't work yet!
     # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2"))
     # pkg_m = LoadedModel(pkg)
@@ -526,6 +526,18 @@ def set_furious(mask_generator):
     # NOTE: Not baseline feature
     mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
 
+def set_autoquant(mask_generator):
+    from torchao import autoquant
+    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+    # NOTE: Not baseline feature
+    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+    mask_generator.predictor._transforms_device = mask_generator.predictor.device
+    torch.set_float32_matmul_precision('high')
+    # NOTE: this fails when we run
+    # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest
+    # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e
+    # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+
 
 def main(checkpoint_path,
          model_type,
@@ -590,14 +602,7 @@ def main(checkpoint_path,
         set_furious(mask_generator)
     # since autoquant is replicating what furious mode is doing, don't use these two together
     elif use_autoquant:
-        from torchao import autoquant
-        from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
-        mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
-
-        # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
-        # NOTE: Not baseline feature
-        mask_generator.predictor._transforms_device = mask_generator.predictor.device
-        torch.set_float32_matmul_precision('high')
+        set_autoquant(mask_generator)
 
     with open('dog.jpg', 'rb') as f:
         image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
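
Usage-wise, set_autoquant packages up what the removed inline block in main did: wrap the image encoder with autoquant restricted to the float-precision candidates plus a min_sqnr threshold. A hedged sketch of the call pattern on a stand-in module (the encoder here is illustrative and a CUDA device is assumed):

```
import torch
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

# Illustrative stand-in for mask_generator.predictor.model.image_encoder.
encoder = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU()).cuda()

# Restrict autoquant to the fp32/bf16/fp16 candidates and drop any candidate
# whose SQNR against the original output falls below 40.
encoder = autoquant(
    encoder,
    qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    min_sqnr=40,
)

# Running an example batch lets autoquant benchmark the candidates and pick
# a compute dtype per linear layer.
encoder(torch.randn(1, 256, device="cuda"))
```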

torchao/_models/llama/generate.py

Lines changed: 17 additions & 3 deletions
@@ -357,11 +357,19 @@ def main(
         )
 
         if "autoquant_v2-int4" == quantization:
-            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
         elif "autoquant_v2-float8" == quantization:
-            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
+        elif "autoquant_v2-fp" == quantization:
+            model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
+        elif "autoquant_v2-all" == quantization:
+            all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+            if torchao.utils.is_sm_89():
+                # this is fp8 related subclasses, should rename
+                all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST
+            model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length)
         else:
-            model = autoquant_v2(model, manual=True, example_input=inputs)
+            model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length)
 
         print("running generate")
         generate(
@@ -406,6 +414,12 @@ def main(
             model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
         if "autoquant-fp" == quantization:
            model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs)
+        if "autoquant-all" == quantization:
+            all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+            if torchao.utils.is_sm_89():
+                # this is fp8 related subclasses, should rename
+                all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
+            model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs)
         else:
             model = autoquant(model, manual=True, example_input=inputs)
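
Restated for readability (this is a summary of the branches above, not code from the PR): each new flag just selects a different candidate list for autoquant_v2, and the v2 path now also threads batch_size=calibration_seq_length through every call.

```
from torchao.prototype.quantization.autoquant_v2 import (
    DEFAULT_AUTOQUANT_CLASS_LIST,
    DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    OTHER_AUTOQUANT_CLASS_LIST,
)
from torchao.utils import is_sm_89

# Which candidates each autoquant_v2 flag benchmarks per linear layer.
V2_FLAG_TO_CANDIDATES = {
    "autoquant_v2-int4": DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    "autoquant_v2-float8": OTHER_AUTOQUANT_CLASS_LIST,
    "autoquant_v2-fp": DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,  # new in this PR
    "autoquant_v2-all": (                                   # new in this PR
        DEFAULT_AUTOQUANT_CLASS_LIST
        + DEFAULT_INT4_AUTOQUANT_CLASS_LIST
        + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
    ),
}

# The fp8 candidates are only appended when torchao.utils.is_sm_89()
# reports a capable GPU.
if is_sm_89():
    V2_FLAG_TO_CANDIDATES["autoquant_v2-all"] = (
        V2_FLAG_TO_CANDIDATES["autoquant_v2-all"] + OTHER_AUTOQUANT_CLASS_LIST
    )
```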

torchao/prototype/quantization/autoquant_v2.py

Lines changed: 137 additions & 6 deletions
@@ -30,7 +30,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
-    benchmark_model,
+    TorchAOBaseTensor,
 )
 
 from torchao.quantization.granularity import (
@@ -61,6 +61,7 @@
     "autoquant_v2",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
     "_is_linear",
 ]
@@ -288,7 +289,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
             )
         elif (self.logged_data == {}) and not error_on_unseen:
             # default back to non-quantized weight if not seen
-            self = AQFloatLinearWeight.from_float(self.weight)
+            self = AQDefaultLinearWeight.from_float(self.weight)
             return self
 
         # only want to print shape (at start) and final result (at end)
@@ -360,7 +361,7 @@ def count_shapes(self, do_print=True):
             print(f"best_cls={best_cls}\n")
         # TODO handle random cls args/kwargs? or should they be curried?
         if best_cls is None:
-            best_cls = AQFloatLinearWeight
+            best_cls = AQDefaultLinearWeight
 
         self = best_cls.from_float(self.weight)
         return self
@@ -802,7 +803,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
     group_size: int = 256
 
 
-class AQFloatLinearWeight(torch.Tensor, AQMixin):
+class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
     default/non-quantized option. Only implements the bare minimum needed to work with the
@@ -823,6 +824,130 @@ def from_float(cls, weight):
         return weight
 
 
+class Float32Tensor(TorchAOBaseTensor):
+    """ Tensor subclass tensor for fp32 dtype
+    """
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float32)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float32
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.weight),
+        )
+
+    @classmethod
+    def from_float(cls, weight):
+        return cls(weight)
+
+
+@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+
+
+@Float32Tensor.implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@Float32Tensor.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@Float32Tensor.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+class BFloat16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.bfloat16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.bfloat16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class Float16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class AQFloat32LinearWeight(Float32Tensor, AQMixin):
+    """
+    AutoQuantizable version for float32 precision weight
+
+    (also converts input activation and bias to float32, and restores the original precision after
+    linear)
+    """
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQFloat32LinearWeight, cls).from_float(weight)
+
+
+class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
+    """
+    AutoQuantizable version for bfloat16 precision weight
+
+    (also converts input activation and bias to bfloat16, and restores the original precision after
+    linear)
+    """
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQBFloat16LinearWeight, cls).from_float(weight)
+
+
+class AQFloat16LinearWeight(Float16Tensor, AQMixin):
+    """
+    AutoQuantizable version for float16 precision weight
+
+    (also converts input activation and bias to float16, and restores the original precision after
+    linear)
+    """
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQFloat16LinearWeight, cls).from_float(weight)
+
+
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
@@ -936,7 +1061,7 @@ def get_weight_block_size(x):
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
     # AQInt8WeightOnlyQuantizedLinearWeight3,
@@ -945,11 +1070,17 @@ def get_weight_block_size(x):
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
+DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
+    AQFloat32LinearWeight,
+    AQBFloat16LinearWeight,
+    AQFloat16LinearWeight,
+]
+
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
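
The new Float32Tensor/BFloat16Tensor/Float16Tensor subclasses route torch.nn.functional.linear (plus detach/clone/_to_copy) to _quantized_linear_op via TorchAOBaseTensor's implements mechanism; the linear op is where the activation and bias get cast to the weight's dtype and the result gets cast back. As a rough, self-contained analogue of that dispatch pattern (a toy wrapper subclass for illustration, not the torchao machinery):

```
import torch
import torch.nn.functional as F

class CastingWeight(torch.Tensor):
    """Toy analogue of the PR's BFloat16Tensor: hold the weight in a fixed
    compute dtype, intercept F.linear, cast the activation in and the
    result back out."""

    _COMPUTE_DTYPE = torch.bfloat16

    @staticmethod
    def __new__(cls, weight):
        # Wrapper subclass: the real data lives in self.weight below.
        return torch.Tensor._make_wrapper_subclass(cls, weight.shape, dtype=weight.dtype)

    def __init__(self, weight):
        self.weight = weight.to(self._COMPUTE_DTYPE)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            act, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            orig_dtype = act.dtype
            out = F.linear(
                act.to(cls._COMPUTE_DTYPE),
                w.weight,
                bias.to(cls._COMPUTE_DTYPE) if bias is not None else None,
            )
            return out.to(orig_dtype)
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(8, 64)                  # fp32 activation
w = CastingWeight(torch.randn(32, 64))  # weight held in bf16
y = F.linear(x, w)
assert y.dtype == torch.float32         # dtype restored after the bf16 matmul
```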
