diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index b0cee1e918..f17264d734 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -25,6 +25,7 @@
     MXInferenceLinear,
     MXLinear,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -372,3 +373,34 @@ def test_inference_print_str():
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="Reqs sm100")
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
+    """
+    Smoke test for inference compile
+    """
+    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+    config = MXFPConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
diff --git a/torchao/__init__.py b/torchao/__init__.py
index fb96282d77..730cd326fe 100644
--- a/torchao/__init__.py
+++ b/torchao/__init__.py
@@ -43,7 +43,7 @@
     quantize_,
 )
 
-from . import dtypes, optim, swizzle, testing
+from . import dtypes, optim, quantization, swizzle, testing
 
 __all__ = [
     "dtypes",
@@ -53,4 +53,5 @@
     "swizzle",
     "testing",
     "ops",
+    "quantization",
 ]
diff --git a/torchao/core/config.py b/torchao/core/config.py
index fe03ac225b..c9ee24f68c 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -175,6 +175,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.quantization",
     "torchao.sparsity.sparse_api",
     "torchao.prototype.quantization",
+    "torchao.prototype.mx_formats",
 }
diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py
index 7252c33dc9..e707a4889f 100644
--- a/torchao/prototype/mx_formats/__init__.py
+++ b/torchao/prototype/mx_formats/__init__.py
@@ -4,6 +4,7 @@
     MXLinearConfig,
     MXLinearRecipeName,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPConfig
 
 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
@@ -14,4 +15,5 @@
     "MXInferenceLinearConfig",
     "MXLinearConfig",
     "MXLinearRecipeName",
+    "MXFPConfig",
 ]
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index af2d89c112..3e25d2667b 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -17,9 +17,12 @@
 the underlying data fields to the MX matmul.
""" -from typing import Any, Dict +from typing import Any, Dict, Optional import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) from torch.utils._pytree import tree_map import torchao.ops @@ -35,10 +38,12 @@ tensor_size_hpx3_to_fp6x4, ) from torchao.prototype.mx_formats.utils import to_blocked +from torchao.utils import fill_defaults aten = torch.ops.aten MX_OPS_TABLE: Dict[Any, Any] = {} +MX_FUNCTION_TABLE: Dict[Any, Any] = {} def implements(aten_ops): @@ -52,59 +57,88 @@ def decorator(func): return decorator -@implements([aten.detach.default]) -def mx_desugar_op(aten_op, args, kwargs=None): - old = args[0] - new_data = aten_op(old._data, *args[1:], **kwargs) - new = MXTensor( - old._scale_e8m0, - new_data, - old._elem_dtype, - old._block_size, - old._orig_dtype, - old._use_fp4_custom_triton_dequant_kernel, - old._gemm_kernel_choice, - old._pack_fp6, +def implements_func(torch_ops): + """Register torch ops to the mx op table for torch function""" + + def decorator(func): + for op in torch_ops: + MX_FUNCTION_TABLE[op] = func + return func + + return decorator + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(func) ) - return new -@implements([aten.mm.default, aten.matmul.default]) -def mx_mm(aten_op, args, kwargs=None): - a = args[0] - b = args[1] - assert isinstance(a, MXTensor) and isinstance(b, MXTensor) - assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported" - if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): +def _get_gemm_choice( + choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice] +) -> MXGemmKernelChoice: + if choice_a is not None and choice_b is not None: + assert choice_a == choice_b, ( + "Both MXTensor inputs must have the same gemm config if specified" + ) + return choice_a + + # Assert that at least one is set and return that one + assert choice_a is not None or choice_b is not None, ( + "At least one gemm choice must be specified" + ) + return choice_a if choice_a is not None else choice_b + + +def _addmm_mx_dispatch( + a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: + """ + Core implementation shared between mx_mm and mx_addmm. + The only difference is whether bias is None or not. 
+ """ + gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) + + if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a._data.is_contiguous() assert b._data.t().is_contiguous() + assert a._block_size == 32, f"Invalid block size {a._block_size}" + assert b._block_size == 32, f"Invalid block size {b._block_size}" - # TODO(future PR): use block_size instead of hardcoding 32 - a_scale = a._scale_e8m0.view(M, K // 32) - b_scale = b._scale_e8m0.view(N, K // 32) + a_scale = a._scale_e8m0.view(M, K // a._block_size) + b_scale = b._scale_e8m0.view(N, K // b._block_size) a_scale_block = to_blocked(a_scale) b_scale_block = to_blocked(b_scale) + if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn - assert a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS, ( + assert gemm_choice is MXGemmKernelChoice.CUBLAS, ( "CUBLAS is the only supported kernel choice for MX FP8 operations" ) + res = torch._scaled_mm( a._data, b._data, a_scale_block.view(torch.float8_e8m0fnu), b_scale_block.view(torch.float8_e8m0fnu), + bias=bias, out_dtype=torch.bfloat16, ) else: assert a._elem_dtype == DTYPE_FP4 assert b._elem_dtype == DTYPE_FP4 - assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported" + assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported" + # FP4 operations res = torchao.ops.mx_fp4_bf16( a._data, b._data, a_scale_block, b_scale_block ) + # TODO add optional bias to kernel + if bias is not None: + res = res + bias + else: # emulated MX gemm a_hp = a.to_dtype(a._orig_dtype) @@ -112,12 +146,48 @@ def mx_mm(aten_op, args, kwargs=None): # assert memory layout we expect to be required in hardware assert a_hp.is_contiguous() assert b_hp.t().is_contiguous() - res = aten_op(a_hp, b_hp) + + # Call appropriate aten_op based on whether bias is provided + if bias is not None: + res = aten_op(bias, a_hp, b_hp) # addmm + else: + res = aten_op(a_hp, b_hp) # mm + return res +@implements_func([aten.linear.default]) +def mx_linear(func, types, args, kwargs): + a, b = args[0], args[1] + assert isinstance(a, MXTensor) and isinstance(b, MXTensor) + bias = args[2] if len(args) == 3 else None + return _addmm_mx_dispatch(a, b.t(), func, bias=bias) + + +@implements([aten.mm.default, aten.matmul.default]) +def mx_mm(func, types, args, kwargs): + a = args[0] + b = args[1] + assert isinstance(a, MXTensor) and isinstance(b, MXTensor) + + return _addmm_mx_dispatch(a, b, func) + + +@implements([aten.addmm.default]) +def mx_addmm(func, types, args, kwargs): + assert ( + isinstance(args[0], torch.Tensor) + and isinstance(args[1], MXTensor) + and isinstance(args[2], MXTensor) + ) + bias = args[0] + a = args[1] + b = args[2] + return _addmm_mx_dispatch(a, b, func, bias=bias) + + @implements([aten.t.default]) -def mx_t(aten_op, args, kwargs=None): +def mx_t(func, types, args, kwargs): # For now, only transpose(input, 0, 1) is supported. old = args[0] new = MXTensor( @@ -134,7 +204,7 @@ def mx_t(aten_op, args, kwargs=None): @implements([aten.sum.dim_IntList]) -def mx_cast_up_op(aten_op, args, kwargs=None): +def mx_cast_up_op(func, types, args, kwargs): """Be careful with this function, this is a "fallback" op that casts the output of the op to the original precision. And performs the op. 
@@ -150,11 +220,11 @@ def unwrap(x):
 
     new_args = tree_map(unwrap, args)
     new_kwargs = tree_map(unwrap, kwargs)
-    return aten_op(*new_args, **new_kwargs)
+    return func(*new_args, **new_kwargs)
 
 
 @implements([aten.view.default])
-def mx_view_op(aten_op, args, kwargs=None):
+def mx_view_op(func, types, args, kwargs):
     data = args[0]._data
     new_size = args[1]
     if args[0]._elem_dtype == DTYPE_FP4:
@@ -163,7 +233,7 @@
     elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6:
         # special case fp6 as we pack 4 elements in 3 bytes
        new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous())
-    new_data = aten_op(data, new_size, *args[2:], **kwargs)
+    new_data = func(data, new_size, *args[2:], **kwargs)
     return MXTensor(
         args[0]._scale_e8m0,
         new_data,
@@ -176,28 +246,120 @@
 
 
-@implements([aten._to_copy.default])
-def autocast_to_copy(aten_op, args, kwargs=None):
-    """This gets called when running matmul under autocast
-    when the input is a MXTensor, presenting as a fp32
-    tensor.
-    """
-    assert isinstance(args[0], MXTensor)
-    assert len(kwargs) == 1 and "dtype" in kwargs, (
-        "Only support dtype kwarg for autocast"
+@implements([aten.slice.Tensor])
+def mx_slice(func, types, args, kwargs):
+    x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+
+    if step != 1:
+        raise ValueError("Only support aten.slice with step=1")
+
+    M, K = x.shape[0], x.shape[1]
+
+    # TODO why doesn't scale have shape?
+    scale_shaped = x._scale_e8m0.view(M, K // x._block_size)
+
+    if dim == 0:
+        # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduction dim for now
+        sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step).flatten()
+        sliced_data = aten.slice.Tensor(x._data, dim, start, end, step)
+    elif dim == 1:
+        # Slicing along reduction dim
+        if start is not None:
+            # Assert start is a multiple of block_size
+            assert start % x._block_size == 0, (
+                f"Start index {start} must be a multiple of block_size {x._block_size}"
+            )
+
+        if end is not None:
+            # Assert end is a multiple of block_size
+            assert end % x._block_size == 0, (
+                f"End index {end} must be a multiple of block_size {x._block_size}"
+            )
+
+        sliced_data = aten.slice.Tensor(x._data, dim, start, end, step)
+
+        # Calculate which scale elements to keep
+        start_block = 0 if start is None else start // x._block_size
+        end_block = -1 if end is None else end // x._block_size
+
+        # Slice the scale tensor accordingly
+        sliced_scale = aten.slice.Tensor(
+            scale_shaped, 1, start_block, end_block, step
+        ).flatten()
+    else:
+        raise ValueError(
+            f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
+        )
+
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        MXTensor(
+            sliced_scale,
+            sliced_data,
+            x._elem_dtype,
+            x._block_size,
+            x._orig_dtype,
+            x._use_fp4_custom_triton_dequant_kernel,
+            x._gemm_kernel_choice,
+            x._pack_fp6,
+        ),
     )
-    assert kwargs["dtype"] in {
-        torch.float16,
-        torch.bfloat16,
-    }, "Only support floating point conversion for autocast w/ MXTensor"
-    res = MXTensor(
-        args[0]._scale_e8m0,
-        args[0]._data,
-        args[0]._elem_dtype,
-        args[0]._block_size,
-        kwargs["dtype"],
-        args[0]._use_fp4_custom_triton_dequant_kernel,
-        args[0]._gemm_kernel_choice,
-        args[0]._pack_fp6,
+
+
+@implements([aten.copy_.default])
+def mx_copy_(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if MXTensor._same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
     )
-    return res
+
+
+@implements([aten._to_copy.default])
+def autocast_to_copy(func, types, args, kwargs):
+    """Autocast + device movement"""
+    assert isinstance(args[0], MXTensor)
+
+    # Handle dtype parameter
+    dtype = kwargs.pop("dtype", None)
+    if dtype is not None:
+        assert dtype in {
+            torch.float16,
+            torch.bfloat16,
+        }, "Only support floating point conversion for autocast w/ MXTensor"
+
+    # Handle device parameter
+    device = kwargs.pop("device", None)
+    if device is not None:
+        # Apply device change using _apply_fn_to_data
+        tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device))
+        tensor = return_and_correct_aliasing(func, args, {}, tensor)
+    else:
+        tensor = args[0]
+
+    # Verify no other kwargs remain
+    assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast"
+
+    # If dtype is specified, create a new MXTensor with the requested dtype
+    if dtype is not None:
+        res = MXTensor(
+            tensor._scale_e8m0,
+            tensor._data,
+            tensor._elem_dtype,
+            tensor._block_size,
+            dtype,
+            tensor._use_fp4_custom_triton_dequant_kernel,
+            tensor._gemm_kernel_choice,
+            tensor._pack_fp6,
         )
+        return res
+
+    # If only device was changed, return the device-changed tensor
+    return tensor
diff --git a/torchao/prototype/mx_formats/mx_subclass.py b/torchao/prototype/mx_formats/mx_subclass.py
new file mode 100644
index 0000000000..2c0897d195
--- /dev/null
+++ b/torchao/prototype/mx_formats/mx_subclass.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import types
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.mx_formats import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.config import (
+    _validate_elem_dtype,
+    _validate_gemm_kernel_choice,
+)
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.quantization.quant_api import to_linear_activation_quantized
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
+
+
+@dataclass
+class MXFPConfig(AOBaseConfig):
+    block_size: int = 32
+
+    # Dtypes for Input and Weights
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+
+    # Which kernel to run for mm
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
+
+    # Set some magic perf settings
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        assert self.activation_dtype == self.weight_dtype, (
+            "For now - we only support matching input/weight dtypes."
+ ) + _validate_elem_dtype(self.activation_dtype) + _validate_elem_dtype(self.weight_dtype) + _validate_gemm_kernel_choice( + self.gemm_kernel_choice, self.block_size, self.weight_dtype + ) + + +def _linear_extra_repr(self): + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}" + + +def _input_activation_quant_func_mxfp( + x: torch.Tensor, + activation_dtype: torch.dtype, + block_size: int, + scale: Optional[torch.Tensor] = None, +): + """ """ + + # TODO scale for static quant + + activation = MXTensor.to_mx( + x, + activation_dtype, + block_size=block_size, + gemm_kernel_choice=None, # Get from weight + pack_fp6=False, # TODO + ) + return activation + + +@register_quantize_module_handler(MXFPConfig) +def _mx_inference_linear_transform(module: torch.nn.Module, config: MXFPConfig): + # TODO Sm120 has slightly more restrictive reqs + # TODO handle AMD + assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now" + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + weight = module.weight + + assert weight.dtype == torch.bfloat16, ( + f"Only supporting bf16 out dtype for now, got {weight.dtype}" + ) + + # Convert weight to MX Tensor + quantized_weight = MXTensor.to_mx( + weight, + weight_dtype, + block_size=config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, + pack_fp6=False, # TODO + ) + + input_quant_func = _input_activation_quant_func_mxfp + input_quant_kwargs = { + "block_size": config.block_size, + "activation_dtype": activation_dtype, + "scale": None, + } + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +if TORCH_VERSION_AT_LEAST_2_5: + torch.serialization.add_safe_globals( + [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp] + ) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index f3aca15a73..c4245d303f 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -18,7 +18,7 @@ """ from enum import Enum, auto -from typing import Dict, Union +from typing import Callable, Dict, Union import torch @@ -559,13 +559,27 @@ def __repr__(self): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # avoid circular dependency - from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE + from torchao.prototype.mx_formats.mx_ops import MX_FUNCTION_TABLE, MX_OPS_TABLE if func in MX_OPS_TABLE: - return MX_OPS_TABLE[func](func, args, kwargs) + return MX_OPS_TABLE[func](func, types, args, kwargs) + + # TODO AO BASE_TENSOR doesn't respect dispatch and function modes + # I suspect thtat since we call to funcitonal.linear for the inear wrapper + # tensor and some modes are disabled in __dispatch__ we see this op + if func == torch.ops.aten.linear.default: + return MX_FUNCTION_TABLE[func](func, types, args, kwargs) raise NotImplementedError(f"{func} not implemented") + # @classmethod + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # from torchao.prototype.mx_formats.mx_ops import MX_FUNCTION_TABLE + + # if func in MX_FUNCTION_TABLE: + # return MX_FUNCTION_TABLE[func](func, types, args, 
kwargs) + # raise NotImplementedError(f"{func} not implemented") + def to_dtype(self, target_dtype): return to_dtype( self._data, @@ -631,5 +645,37 @@ def __tensor_unflatten__( metadata["_pack_fp6"], ) + def _apply_fn_to_data(self, fn: Callable): + """Applies a fn to all tensor components stored on this class""" + tensor_names, ctx = self.__tensor_flatten__() + + # Apply the function to each tensor component + new_tensors = {} + for name in tensor_names: + new_tensors[name] = fn(getattr(self, name)) + + return self.__class__.__tensor_unflatten__( + new_tensors, + ctx, + None, # outer_size parameter + None, # outer_stride parameter + ) + # Do not force the MXTensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def _same_metadata(cls, self: "MXTensor", src: "MXTensor") -> bool: + return ( + isinstance(self, MXTensor) + and isinstance(src, MXTensor) + and self._elem_dtype == src._elem_dtype + and self._block_size == src._block_size + and self._orig_dtype == src._orig_dtype + and self._use_fp4_custom_triton_dequant_kernel + == src._use_fp4_custom_triton_dequant_kernel + and self._gemm_kernel_choice == src._gemm_kernel_choice + and self._pack_fp6 == src._pack_fp6 + and self._scale_e8m0.shape == src._scale_e8m0.shape + and self._data.shape == src._data.shape + ) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 8b186f82d6..67c914c124 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -35,13 +35,13 @@ def to_blocked(input_matrix) -> Tensor: padded_cols = n_col_blocks * 4 padded = input_matrix - if (rows, cols) != (padded_rows, padded_cols): - padded = torch.zeros( - (padded_rows, padded_cols), - device=input_matrix.device, - dtype=input_matrix.dtype, - ) - padded[:rows, :cols] = input_matrix + # if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix # Rearrange the blocks blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)