Add all fbgemm kernel Tensors into Int4WeightOnlyConfig and Float8DynamicActivationInt4WeightConfig #2474

Open · wants to merge 1 commit into base: jerryzh168/stack/9
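For context, a minimal usage sketch of the two config entry points this PR wires up, reconstructed from the test changes below. The toy module, shapes, device, and dtype are illustrative assumptions; per the test skip conditions, running this needs PyTorch 2.8+, CUDA with sm90+, and the fbgemm genai kernels.

import torch

from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    quantize_,
)

# illustrative module: the fbgemm int4 kernels operate on bf16 weights on CUDA
m = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

# bf16 activation + int4 weight, preshuffled packing (Int4PreshuffledTensor)
quantize_(m, Int4WeightOnlyConfig(group_size=128, packing_format="preshuffled"))

# other variants exercised by the tests in this PR:
#   quantize_(m, Int4WeightOnlyConfig(group_size=128, packing_format="plain"))
#   quantize_(m, Float8ActivationInt4WeightConfig(group_size=128, packing_format="preshuffled"))
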
4 changes: 2 additions & 2 deletions test/integration/test_serialization_bc.py
@@ -17,13 +17,13 @@
from torchao.utils import _is_fbgemm_genai_gpu_available, is_sm_at_least_90

_MODEL_NAMES = [
"torchao-testing/opt-125m-float8dq-row-fbgemm",
"torchao-testing/opt-125m-float8dq-row-0.13-dev",
"torchao-testing/opt-125m-int4wo-preshuffled-0.13-dev",
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skip("temporary skip since we have some refactor next")
class TestSerializationBC(TestCase):
"""Test we can still load and run serialized model in previous AO versions
we commit to have BC for 3 pytorch releases
@@ -15,9 +15,9 @@
run_tests,
)

from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
Float8ActivationInt4WeightConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -27,44 +27,15 @@
is_sm_at_least_90,
)

if TORCH_VERSION_AT_LEAST_2_8:
BF16_ACT_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

BF16_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

FP8_ACT_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
preshuffle=True,
)

FP8_ACT_BMM_CONFIG = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
preshuffle=True,
)

else:
BF16_ACT_CONFIG = None
BF16_ACT_BMM_CONFIG = None
FP8_ACT_CONFIG = None
FP8_ACT_BMM_CONFIG = None
BF16_ACT_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="preshuffled",
)

FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig(
group_size=128,
packing_format="preshuffled",
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -90,7 +61,7 @@ def test_linear(self, config):

# Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
# @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
@parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG])
def test_bmm(self, bmm_config):
class M(torch.nn.Module):
def __init__(self, weight):
@@ -13,7 +13,7 @@
)

from torchao.quantization import (
FbgemmConfig,
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
@@ -26,19 +26,11 @@
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestFbgemmInt4Tensor(TestCase):
class TestInt4Tensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
self.config = Int4WeightOnlyConfig(
group_size=128,
packing_format="plain",
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -68,13 +60,9 @@ def test_slice(self):
quantize_(dummy, self.config)
weight1 = dummy.weight.narrow(0, 0, 64)
weight2 = dummy.weight.narrow(1, 0, 128)
self.assertEqual(
weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
)
self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
self.assertEqual(
weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
)
self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

# check for sliced weight, before and after float8 quantization
@@ -100,12 +88,10 @@ def test_slice_and_copy_(self):
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, 512)
assert (
param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
)
assert param.data._data.data_ptr() == param_data._data.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
orig_value = param.data.packed_weight[0][0].item()
orig_value = param.data._data[0][0].item()

# dummy_l has random input (shouldn't be 0)
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
@@ -116,7 +102,7 @@ def test_slice_and_copy_(self):
param_data.copy_(quantized)

# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value
assert param.data._data[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
@@ -135,7 +121,7 @@ def forward(self, x):
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantize_(m, self.config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

3 changes: 0 additions & 3 deletions torchao/dtypes/__init__.py
@@ -9,7 +9,6 @@
to_affine_quantized_intx_static,
)
from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
@@ -64,8 +63,6 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_int4",
"FbgemmInt4Tensor",
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
"Int8DynamicActInt4WeightCPULayout",
4 changes: 4 additions & 0 deletions torchao/quantization/__init__.py
@@ -44,6 +44,7 @@
from .quant_api import (
CutlassInt4PackedLayout,
FbgemmConfig,
Reviewer (Contributor): What's the plan for FbgemmConfig? It looks like it was added only ~1.5 months ago, but it's technically public API. Do we know if anyone is using it already? I don't think it has been released yet, so I wonder if it's OK to just remove it.

Author (Contributor): We'll remove it. It is used in some internal scripts, but we'll update those as well.

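For anyone already using FbgemmConfig (see the reviewer question above), a migration sketch inferred from the test configurations this PR removes and adds; this is not an official guide, just the mapping the diff suggests.

from torchao.quantization import (
    Float8ActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
)

# before: FbgemmConfig(input_dtype=torch.bfloat16, weight_dtype=torch.int4,
#                      output_dtype=torch.bfloat16, block_size=[1, 128], preshuffle=True)
bf16_act_config = Int4WeightOnlyConfig(group_size=128, packing_format="preshuffled")

# before: FbgemmConfig(input_dtype=e4m3_dtype, weight_dtype=torch.int4,
#                      output_dtype=torch.bfloat16, block_size=[1, 128], preshuffle=True)
fp8_act_config = Float8ActivationInt4WeightConfig(group_size=128, packing_format="preshuffled")
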
Float8ActivationInt4WeightConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
@@ -90,6 +91,7 @@
from .quantize_.workflows import (
Float8Tensor,
Int4PreshuffledTensor,
Int4Tensor,
)
from .smoothquant import (
SmoothFakeDynamicallyQuantizedLinear,
@@ -141,6 +143,7 @@
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int4WeightOnlyConfig",
"Float8ActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
@@ -154,6 +157,7 @@
"ModuleFqnToConfig",
"FbgemmConfig",
# tensor subclasses
"Int4Tensor",
"Int4PreshuffledTensor",
"Float8Tensor",
# smooth quant - subject to change
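Together with the export changes above, the new tensor subclass is reachable from the top-level package. A quick check sketch, assuming this branch is installed; the module shape and device are illustrative.

import torch

from torchao.quantization import Int4Tensor, Int4WeightOnlyConfig, quantize_

m = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(m, Int4WeightOnlyConfig(group_size=128, packing_format="plain"))

# the weight is now the plain Int4Tensor, carrying _data, scale, and zero_point
print(type(m.weight))
print(m.weight._data.shape, m.weight.scale.shape)
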
74 changes: 71 additions & 3 deletions torchao/quantization/quant_api.py
@@ -49,7 +49,6 @@
to_affine_quantized_floatx_static,
to_affine_quantized_intx,
to_fbgemm_fp8,
to_fbgemm_int4,
to_marlinqqq_quantized_intx,
)
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -72,6 +71,8 @@
from torchao.quantization.quantize_.workflows import (
Float8Tensor,
Int4PreshuffledTensor,
Int4Tensor,
PackingFormat,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
@@ -1108,6 +1109,8 @@ class Int4WeightOnlyConfig(AOBaseConfig):
`zero_point_domain`: data type of zero points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
`preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
`packing_format`: the packing format for the int4 tensor; this is the new structure we plan to use.
The existing layout-based structure uses `_legacy`.
"""

group_size: int = 128
@@ -1116,6 +1119,9 @@
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
set_inductor_config: bool = True
preserve_zero: Optional[bool] = None
# since not all tensors are migrated to the new structure yet,
# we use `_legacy` to represent the previous layout
packing_format: PackingFormat = "_legacy"


# for BC
@@ -1133,15 +1139,33 @@ def _int4_weight_only_quantize_tensor(weight, config):
layout = config.layout
use_hqq = config.use_hqq
zero_point_domain = config.zero_point_domain
packing_format = config.packing_format

if weight.shape[-1] % group_size != 0:
logger.info(
f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
)
return weight

block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

if packing_format == "preshuffled":
new_weight = Int4PreshuffledTensor.from_float(
weight,
block_size,
activation_dtype=torch.bfloat16,
)
return new_weight
elif packing_format == "plain":
new_weight = Int4Tensor.from_float(
weight,
block_size,
)
return new_weight

assert packing_format == "_legacy", f"Unsupported packing_format: {packing_format}"

mapping_type = MappingType.ASYMMETRIC
block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
target_dtype = torch.int32
quant_min = 0
quant_max = 15
Expand Down Expand Up @@ -1213,6 +1237,38 @@ def _int4_weight_only_transform(
return module


@dataclass
class Float8ActivationInt4WeightConfig(AOBaseConfig):
group_size: int = 128
packing_format: PackingFormat = "preshuffled"


@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
def _float8_activation_int4_weight_transform(
module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
) -> torch.nn.Module:
assert hasattr(module, "weight"), (
"applying int8 weight only quant requires module to have weight attribute"
+ " but {module} does not have one"
)
group_size = config.group_size
packing_format = config.packing_format

assert packing_format == "preshuffled", (
f"only preshuffled packing_format supported right now, got: {packing_format}"
)
weight = module.weight
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
new_weight = Int4PreshuffledTensor.from_float(
module.weight,
block_size,
activation_dtype=torch.float8_e4m3fn,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
class Int8WeightOnlyConfig(AOBaseConfig):
"""
@@ -1484,6 +1540,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
"""

weight_dtype: torch.dtype = e4m3_dtype
packing_format: PackingFormat = "plain"
set_inductor_config: bool = True


@@ -1493,6 +1550,10 @@ class Float8WeightOnlyConfig(AOBaseConfig):

def _float8_weight_only_quant_tensor(weight, config):
block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
packing_format = config.packing_format
assert packing_format == "plain", (
f"Only plain packing_format is supported, got: {packing_format}"
)
new_weight = Float8Tensor.from_float(
weight,
config.weight_dtype,
@@ -1610,6 +1671,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
mm_config: Optional[Float8MMConfig] = None
activation_scale_ub: float = 1200.0
packing_format: PackingFormat = "plain"
set_inductor_config: bool = True

def __post_init__(self):
@@ -1631,8 +1693,13 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
weight_dtype = config.weight_dtype
granularity = config.granularity
mm_config = config.mm_config
packing_format = config.packing_format
activation_scale_ub = config.activation_scale_ub

assert packing_format == "plain", (
f"Only plain packing_format is supported, got {packing_format}"
)

# Ensure works on device
_check_hardware_support(granularity)
activation_granularity, weight_granularity = granularity
@@ -1641,6 +1708,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return weight

if isinstance(weight_granularity, PerRow):
assert weight.dtype == torch.bfloat16, (
"PerRow quantization only works for bfloat16 precision input weight"
@@ -2083,7 +2151,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
activation_dtype=torch.bfloat16,
)
else:
weight = to_fbgemm_int4(
weight = Int4Tensor.from_float(
module.weight,
config.block_size,
)
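
To summarize the new routing in _int4_weight_only_quantize_tensor: block_size keeps every leading dimension at 1 and places group_size on the last dimension, and packing_format then selects the tensor subclass. A simplified sketch of that dispatch (the real implementation also covers the `_legacy` affine-quantized path and the HQQ option):

import torch

from torchao.quantization import Int4PreshuffledTensor, Int4Tensor

def _block_size_for(weight: torch.Tensor, group_size: int) -> tuple:
    # 2D linear weight (out, in) with group_size=128 -> (1, 128)
    # 3D bmm weight (batch, out, in) with group_size=128 -> (1, 1, 128)
    return tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

def _quantize_int4_weight(weight: torch.Tensor, group_size: int, packing_format: str):
    block_size = _block_size_for(weight, group_size)
    if packing_format == "preshuffled":
        # the weight-only path assumes bf16 activations, matching Int4WeightOnlyConfig
        return Int4PreshuffledTensor.from_float(
            weight, block_size, activation_dtype=torch.bfloat16
        )
    if packing_format == "plain":
        return Int4Tensor.from_float(weight, block_size)
    raise NotImplementedError(f"sketch does not cover packing_format: {packing_format}")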