Commit 1c0ea5b

Fix float related autoquant options (#1562)
* Fix float related autoquant options

  Summary: Forgot to add a test for previous changes; this fixes some implementations for the quantized model.

  Test Plan: python test/integration/test_integration.py -k test_autoquant_float

* skip non-cuda runs

* update torch version requirement

* typo
1 parent 71c6231 commit 1c0ea5b
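For reference, the float autoquant option exercised by this commit can be driven as in the minimal sketch below. The model and shapes are illustrative stand-ins; torchao.autoquant and DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST are taken directly from the new test.

import torch
import torchao

# Illustrative model and input; any float32 CUDA model would do.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to("cuda").to(torch.float32)
example_input = torch.randn(128, 128, device="cuda", dtype=torch.float32)

# Restrict autoquant's search space to the float (non-integer) options.
torchao.autoquant(
    model,
    qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
)
out = model(example_input)  # running the model completes selection, as in the test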

File tree

3 files changed: +47, -3 lines changed

test/integration/test_integration.py
torchao/_models/utils.py
torchao/quantization/autoquant.py


test/integration/test_integration.py

Lines changed: 36 additions & 0 deletions

@@ -1747,6 +1747,42 @@ def test_autoquant_min_sqnr(self, device, dtype):
         # setting min_sqnr for individual linear to be 60 allows us to achieve >= 50 final sqnr
         self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}")
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
+    )
+    def test_autoquant_float(self):
+        device = "cuda"
+        dtype = torch.float32
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+        model = (
+            torch.nn.Sequential(
+                torch.nn.ReLU(),
+                torch.nn.Linear(k, n),
+                torch.nn.ReLU(),
+            )
+            .to(device)
+            .to(dtype)
+        )
+        ref = model(example_input)
+        torchao.autoquant(
+            model,
+            qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
+        )
+        out = model(example_input)
+        from torchao.quantization.autoquant import (
+            BFloat16Tensor,
+            Float16Tensor,
+            Float32Tensor,
+        )
+
+        self.assertIn(
+            type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor]
+        )
+        print(compute_error(out, ref))
+        self.assertGreater(compute_error(out, ref), 60)
+
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
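The final assertion treats compute_error as an SQNR measurement in decibels, so requiring > 60 dB means the float-autoquantized output stays very close to the float32 reference. As a sketch of the kind of check involved (an assumed form, not copied from torchao's helper):

import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: signal norm over error norm.
    # Assumed definition for illustration; see torchao's compute_error for the real one.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - out))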

torchao/_models/utils.py

Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,7 @@ def write_json_result_ossci(output_json_path, headers, row):
                 "arch": mapping_headers["arch"],
                 "min_sqnr": mapping_headers["min_sqnr"],
                 # True means compile is enabled, False means eager mode
-                "complie": mapping_headers["compile"],
+                "compile": mapping_headers["compile"],
             },
         },
         "model": {
@@ -87,7 +87,7 @@ def write_json_result_local(output_json_path, headers, row):
                 "arch": mapping_headers["arch"],
                 "min_sqnr": mapping_headers["min_sqnr"],
                 # True means compile is enabled, False means eager mode
-                "complie": mapping_headers["compile"],
+                "compile": mapping_headers["compile"],
             },
         },
         "model": {

torchao/quantization/autoquant.py

Lines changed: 9 additions & 1 deletion

@@ -778,7 +778,7 @@ def _apply_fn_to_data(self, fn):
 
     @classmethod
     def from_float(cls, weight):
-        return cls(weight)
+        return Float32Tensor(weight)
 
 
 @Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
@@ -829,6 +829,10 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
             bias.to(_DTYPE) if bias is not None else bias,
         ).to(dtype=orig_dtype)
 
+    @classmethod
+    def from_float(cls, weight):
+        return BFloat16Tensor(weight)
+
 
 class Float16Tensor(Float32Tensor):
     def __init__(self, weight):
@@ -844,6 +848,10 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
             bias.to(_DTYPE) if bias is not None else bias,
         ).to(dtype=orig_dtype)
 
+    @classmethod
+    def from_float(cls, weight):
+        return Float16Tensor(weight)
+
 
 class AQFloat32LinearWeight(Float32Tensor, AQMixin):
     """
