Commit ddb7f83

Revert "Remove torchao.quantization.prototype" (#1919)
Revert "Remove torchao.quantization.prototype (#1889)" This reverts commit 576cf6b.
1 parent d456ea1 · commit ddb7f83

File tree

10 files changed: +175 −0 lines changed

test/quantization/test_qat.py

Lines changed: 56 additions & 0 deletions
@@ -1133,6 +1133,62 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_prototype_bc(self):
+        """
+        Just to make sure we can import all the old prototype paths.
+        We will remove this test in the near future when we actually break BC.
+        """
+        from torchao.quantization.prototype.qat import (  # noqa: F401, F811, I001
+            disable_4w_fake_quant,
+            disable_8da4w_fake_quant,
+            enable_4w_fake_quant,
+            enable_8da4w_fake_quant,
+            ComposableQATQuantizer,
+            Int8DynActInt4WeightQATLinear,
+            Int4WeightOnlyEmbeddingQATQuantizer,
+            Int4WeightOnlyQATQuantizer,
+            Int8DynActInt4WeightQATQuantizer,
+        )
+        from torchao.quantization.prototype.qat._module_swap_api import (  # noqa: F401, F811
+            disable_4w_fake_quant_module_swap,
+            enable_4w_fake_quant_module_swap,
+            disable_8da4w_fake_quant_module_swap,
+            enable_8da4w_fake_quant_module_swap,
+            Int4WeightOnlyQATQuantizerModuleSwap,
+            Int8DynActInt4WeightQATQuantizerModuleSwap,
+        )
+        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (  # noqa: F401, F811
+            AffineFakeQuantizedTensor,
+            to_affine_fake_quantized,
+        )
+        from torchao.quantization.prototype.qat.api import (  # noqa: F401, F811
+            ComposableQATQuantizer,
+            FakeQuantizeConfig,
+        )
+        from torchao.quantization.prototype.qat.embedding import (  # noqa: F401, F811
+            FakeQuantizedEmbedding,
+            Int4WeightOnlyEmbeddingQATQuantizer,
+            Int4WeightOnlyEmbedding,
+            Int4WeightOnlyQATEmbedding,
+        )
+        from torchao.quantization.prototype.qat.fake_quantizer import (  # noqa: F401, F811
+            FakeQuantizer,
+        )
+        from torchao.quantization.prototype.qat.linear import (  # noqa: F401, F811
+            disable_4w_fake_quant,
+            disable_8da4w_fake_quant,
+            enable_4w_fake_quant,
+            enable_8da4w_fake_quant,
+            FakeQuantizedLinear,
+            Int4WeightOnlyQATLinear,
+            Int4WeightOnlyQATQuantizer,
+            Int8DynActInt4WeightQATLinear,
+            Int8DynActInt4WeightQATQuantizer,
+        )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
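
Aside: since every restored prototype module below simply re-exports from torchao.quantization.qat, the legacy names are the same objects as the canonical ones, not copies. A minimal sketch of that property, assuming a torchao build that includes this commit:

from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer as LegacyQuantizer,
)
from torchao.quantization.qat import (
    Int8DynActInt4WeightQATQuantizer as CanonicalQuantizer,
)

# The shim re-exports, so both paths resolve to the identical class object.
assert LegacyQuantizer is CanonicalQuantizer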

torchao/quantization/prototype/__init__.py

Whitespace-only changes.
torchao/quantization/prototype/qat/README.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+Note: QAT has been moved to torchao/quantization/qat.
+This is a legacy folder only for backward compatibility
+and will be removed in the near future.
torchao/quantization/prototype/qat/__init__.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from torchao.quantization.qat import (
+    ComposableQATQuantizer,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    Int4WeightOnlyQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.qat.linear import (
+    Int8DynActInt4WeightQATLinear,
+    disable_4w_fake_quant,
+    disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
+    enable_8da4w_fake_quant,
+)
+
+__all__ = [
+    "disable_4w_fake_quant",
+    "disable_8da4w_fake_quant",
+    "enable_4w_fake_quant",
+    "enable_8da4w_fake_quant",
+    "ComposableQATQuantizer",
+    "Int4WeightOnlyQATQuantizer",
+    "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
+    "Int8DynActInt4WeightQATLinear",
+]
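
With this shim in place, pre-move user code that composes quantizers through the old namespace keeps working. A hedged sketch of the usual flow (the composition pattern follows the torchao QAT documentation; the toy model and sizes are illustrative):

import torch

from torchao.quantization.prototype.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

# Toy model (illustrative): an embedding feeding a linear layer.
model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 256),
    torch.nn.Linear(256, 256),
)

quantizer = ComposableQATQuantizer(
    [
        Int8DynActInt4WeightQATQuantizer(),
        Int4WeightOnlyEmbeddingQATQuantizer(),
    ]
)
model = quantizer.prepare(model)  # swap in fake-quantized modules for training
# ... QAT fine-tuning loop would go here ...
model = quantizer.convert(model)  # swap in actually-quantized modules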
torchao/quantization/prototype/qat/_module_swap_api.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# For backward compatibility only
+# These will be removed in the future
+
+from torchao.quantization.qat.linear import (
+    Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap,
+)
+from torchao.quantization.qat.linear import (
+    Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap,
+)
+from torchao.quantization.qat.linear import (
+    disable_4w_fake_quant as disable_4w_fake_quant_module_swap,
+)
+from torchao.quantization.qat.linear import (
+    disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap,
+)
+from torchao.quantization.qat.linear import (
+    enable_4w_fake_quant as enable_4w_fake_quant_module_swap,
+)
+from torchao.quantization.qat.linear import (
+    enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap,
+)
+
+__all__ = [
+    "Int8DynActInt4WeightQATQuantizerModuleSwap",
+    "Int4WeightOnlyQATQuantizerModuleSwap",
+    "enable_8da4w_fake_quant_module_swap",
+    "disable_8da4w_fake_quant_module_swap",
+    "enable_4w_fake_quant_module_swap",
+    "disable_4w_fake_quant_module_swap",
+]
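
Because these are plain "import ... as ..." aliases rather than separate implementations, the old module-swap names compare identical to the canonical ones:

from torchao.quantization.prototype.qat._module_swap_api import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)
from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

# Same class object under both names.
assert Int8DynActInt4WeightQATQuantizerModuleSwap is Int8DynActInt4WeightQATQuantizer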
torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from torchao.quantization.qat.affine_fake_quantized_tensor import (
+    AffineFakeQuantizedTensor,
+    to_affine_fake_quantized,
+)
+
+__all__ = [
+    "AffineFakeQuantizedTensor",
+    "to_affine_fake_quantized",
+]
torchao/quantization/prototype/qat/api.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from torchao.quantization.qat.api import (
+    ComposableQATQuantizer,
+    FakeQuantizeConfig,
+)
+
+__all__ = [
+    "ComposableQATQuantizer",
+    "FakeQuantizeConfig",
+]
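
FakeQuantizeConfig, re-exported here, describes how a tensor should be fake-quantized. A short sketch of constructing configs through the legacy path (the dtype and granularity choices below follow the torchao QAT docs and are illustrative):

import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig

# Asymmetric per-token int8 for activations, symmetric per-group int4 for weights.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)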
torchao/quantization/prototype/qat/embedding.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from torchao.quantization.qat.embedding import (
+    FakeQuantizedEmbedding,
+    Int4WeightOnlyEmbedding,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    Int4WeightOnlyQATEmbedding,
+)
+
+__all__ = [
+    "FakeQuantizedEmbedding",
+    "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int4WeightOnlyEmbedding",
+    "Int4WeightOnlyQATEmbedding",
+]
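
FakeQuantizedEmbedding is a drop-in replacement for torch.nn.Embedding whose weight is fake-quantized on the fly. A hedged sketch (the weight_config keyword and the sizes are assumptions based on its FakeQuantizedLinear sibling in the torchao QAT docs):

import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding

weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
emb = FakeQuantizedEmbedding(1024, 256, weight_config=weight_config)

tokens = torch.randint(0, 1024, (8,))
out = emb(tokens)  # embeddings computed from the fake-quantized weight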
torchao/quantization/prototype/qat/fake_quantizer.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from torchao.quantization.qat.fake_quantizer import (
+    FakeQuantizer,
+)
+
+__all__ = [
+    "FakeQuantizer",
+]
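
FakeQuantizer is the nn.Module that applies a FakeQuantizeConfig to a tensor; the fake-quantized linear and embedding modules build on it. A minimal sketch (shape and group size are illustrative):

import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.fake_quantizer import FakeQuantizer

fake_quantizer = FakeQuantizer(FakeQuantizeConfig(torch.int4, group_size=32))
w = torch.randn(64, 64)
w_fq = fake_quantizer(w)  # same shape and float dtype, values snapped to an int4 grid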
torchao/quantization/prototype/qat/linear.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from torchao.quantization.qat.linear import (
+    FakeQuantizedLinear,
+    Int4WeightOnlyQATLinear,
+    Int4WeightOnlyQATQuantizer,
+    Int8DynActInt4WeightQATLinear,
+    Int8DynActInt4WeightQATQuantizer,
+    disable_4w_fake_quant,
+    disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
+    enable_8da4w_fake_quant,
+)
+
+__all__ = [
+    "disable_4w_fake_quant",
+    "disable_8da4w_fake_quant",
+    "enable_4w_fake_quant",
+    "enable_8da4w_fake_quant",
+    "FakeQuantizedLinear",
+    "Int4WeightOnlyQATLinear",
+    "Int4WeightOnlyQATQuantizer",
+    "Int8DynActInt4WeightQATLinear",
+    "Int8DynActInt4WeightQATQuantizer",
+]
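
The enable/disable helpers re-exported here are designed for torch.nn.Module.apply, which visits every submodule. A hedged sketch of the usual flow (toy model; the toggle pattern follows the torchao QAT docs):

import torch

from torchao.quantization.prototype.qat.linear import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))  # toy model (illustrative)
quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)

model.apply(disable_8da4w_fake_quant)  # e.g. run some steps without fake quant
model.apply(enable_8da4w_fake_quant)   # re-enable fake quant for QAT fine-tuning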
