diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py
index 9777553433..622f077045 100644
--- a/benchmarks/float8/bench_padding.py
+++ b/benchmarks/float8/bench_padding.py
@@ -50,10 +50,10 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
     a_config = ScaledMMConfig(
-        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
     )
     b_config = ScaledMMConfig(
-        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
     )
     a_config = LinearMMConfig(a_config, a_config, a_config)
     b_config = LinearMMConfig(b_config, b_config, b_config)
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 2a875c44d6..c8e1d98245 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -494,7 +494,7 @@ def test_different_configs_error(self):
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+    def test_pad_dimensions(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
         input_dtype = torch.float8_e4m3fn
         compare_type = torch.float32
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 8a0458bec3..997d1b2961 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -40,14 +40,25 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_dimensions: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
-    x_shape = (16, 16)
+
+    if pad_dimensions:
+        x_shape = (17, 17)
+    else:
+        x_shape = (16, 16)
+
     linear_dtype = torch.bfloat16
 
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-    m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+
+    if pad_dimensions:
+        m_ref = nn.Linear(17, 35, bias=True, device="cuda", dtype=linear_dtype)
+    else:
+        m_ref = nn.Linear(16, 16, bias=True, device="cuda", dtype=linear_dtype)
+
     m_fp8 = Float8Linear.from_float(
         copy.deepcopy(m_ref),
@@ -71,6 +82,7 @@ def _get_config(
     scaling_type_weight,
     scaling_type_grad_output,
     emulate,
+    pad_dimensions,
 ):
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input = CastConfig(
@@ -99,11 +111,13 @@ def _get_config(
         cast_config_weight=cast_config_weight,
         cast_config_grad_output=cast_config_grad_output,
         emulate=emulate,
+        pad_dimensions=pad_dimensions,
     )
     return config
 
 
 @pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
@@ -113,7 +127,9 @@ def _get_config(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_dimensions", [False, True]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +138,19 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 
@@ -147,6 +165,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_dimensions", [False, True]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +176,19 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 
@@ -180,6 +203,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_dimensions", [True, False]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +214,19 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index eb28dcbd8e..b00a449f8c 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -120,7 +120,7 @@ class Float8LinearConfig:
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
     # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.
     # This can cause a memory spike however so we keep this off by default.
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index cb0ff7afb0..b52445015d 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -141,21 +141,21 @@ def __init__(self, *args, **kwargs):
                 emulate,
                 self.config.gemm_config_output.use_fast_accum,
                 False,
-                self.config.pad_inner_dim,
+                self.config.pad_dimensions,
             ),
             # grad_input
             ScaledMMConfig(
                 emulate,
                 self.config.gemm_config_grad_input.use_fast_accum,
                 False,
-                self.config.pad_inner_dim,
+                self.config.pad_dimensions,
             ),
             # grad_weight
             ScaledMMConfig(
                 emulate,
                 self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
-                self.config.pad_inner_dim,
+                self.config.pad_dimensions,
             ),
         )
 
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index f8115649b3..a9e3740ce8 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -134,6 +134,8 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
+    out_shape = (a._data.size(0), b._data.size(1))
+
     scaled_mm_config = choose_scaled_mm_config(
         a._gemm_input_role,
         a._linear_mm_config,
@@ -141,20 +143,25 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
         b._linear_mm_config,
     )
 
-    if scaled_mm_config.pad_inner_dim:
+    if scaled_mm_config.pad_dimensions:
         assert a._data.size(1) == b._data.size(
             0
         ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
-        b_data = pad_tensor_for_matmul(b_data, dims=0)
+        b_data = pad_tensor_for_matmul(b_data, dims=[0,1])
 
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
         b_data = b_data.t().contiguous().t()
     b_scale = b._scale
-    return a_data, a_scale, b_data, b_scale
+    return a_data, a_scale, b_data, b_scale, out_shape
+
+def postprocess_addmm(out: torch.Tensor, scaled_mm_config, out_shape):
+    if scaled_mm_config.pad_dimensions:
+        out = out[:, :out_shape[1]]
+    return out
 
 
 @implements([aten.mm.default, aten.matmul.default])
 def float8_mm(aten_op, args, kwargs=None):
@@ -166,7 +173,7 @@ def float8_mm(aten_op, args, kwargs=None):
     ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
         type(a), type(b)
     )
-    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     scaled_mm_config = choose_scaled_mm_config(
         a._gemm_input_role,
@@ -188,6 +195,7 @@ def float8_mm(aten_op, args, kwargs=None):
         bias=None,
         use_fast_accum=scaled_mm_config.use_fast_accum,
     )
+    tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
     return tensor_out
 
 
@@ -201,7 +209,7 @@ def float8_addmm(aten_op, args, kwargs=None):
     bias = args[0]
     a = args[1]
     b = args[2]
-    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
     scaled_mm_config = choose_scaled_mm_config(
@@ -225,6 +233,7 @@ def float8_addmm(aten_op, args, kwargs=None):
         bias=bias,
         use_fast_accum=scaled_mm_config.use_fast_accum,
     )
+    tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
     return tensor_out
 
 
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
index 63110101a5..784789d922 100644
--- a/torchao/float8/float8_tensor.py
+++ b/torchao/float8/float8_tensor.py
@@ -53,14 +53,14 @@ class ScaledMMConfig(NamedTuple):
         emulate (bool): Whether to emulate the matmuls in fp32.
         use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
         fp8_output (bool): Whether to output the result of the scaled_mm in fp8.
-        pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
+        pad_dimensions (bool): Whether to pad the inner dimension of a and b with 0s.
                               This is needed for matmuls not aligned to 16.
     """
 
     emulate: bool = False
     use_fast_accum: bool = False
     fp8_output: bool = False
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
 
 class LinearMMConfig(NamedTuple):
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index 816a55ee61..80d81454f5 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -22,13 +22,13 @@ class Float8MMConfig(NamedTuple):
     Attributes:
         emulate (bool): Whether to emulate the matmuls in fp32.
        use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
-        pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
+        pad_dimensions (bool): Whether to pad the inner dimension of a and b with 0s.
                               This is needed for matmuls not aligned to 16.
     """
 
     emulate: bool = False
     use_fast_accum: bool = False
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
 
 def preprocess_data(
@@ -44,7 +44,7 @@ def preprocess_data(
     Returns:
         Preprocessed tensors A and B in the format for _scaled_mm.
     """
-    if scaled_mm_config.pad_inner_dim:
+    if scaled_mm_config.pad_dimensions:
         assert a_data.size(1) == b_data.size(
             0
         ), f"Inner dims must match for mm, got {a_data.size(1)} and {b_data.size(0)}"
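
Note on the mechanics (editorial sketch, not part of the diff): with pad_dimensions enabled, preprocess_addmm zero-pads the inner dimension of a and both dimensions of b up to multiples of 16 before torch._scaled_mm runs, and postprocess_addmm slices the padded N columns back off so callers still see the original (M, N) output. The snippet below replays that flow with plain dense tensors; pad_to_multiple is a hypothetical stand-in for torchao's pad_tensor_for_matmul, and a regular matmul stands in for _scaled_mm. Shapes mirror the (17, 17) input and Linear(17, 35) case added to test_compile.py.

    # Standalone illustration of the pad-then-slice scheme (hypothetical helper name).
    import torch
    import torch.nn.functional as F

    def pad_to_multiple(t: torch.Tensor, dims, multiple: int = 16) -> torch.Tensor:
        # Zero-pad the requested dims of a 2D tensor up to the next multiple of `multiple`.
        # F.pad's pad list is ordered (last-dim-left, last-dim-right, first-dim-before, first-dim-after).
        pad = [0, 0, 0, 0]
        for d in dims:
            pad[2 * (t.dim() - 1 - d) + 1] = (-t.size(d)) % multiple
        return F.pad(t, pad)

    a = torch.randn(17, 17)                  # activation data, K = 17 (not a multiple of 16)
    b = torch.randn(17, 35)                  # weight data laid out as (K, N)
    out_shape = (a.size(0), b.size(1))       # remember the unpadded (M, N)

    a_pad = pad_to_multiple(a, dims=[1])     # pad K only      -> (17, 32)
    b_pad = pad_to_multiple(b, dims=[0, 1])  # pad K and N     -> (32, 48)
    out = a_pad @ b_pad                      # _scaled_mm in the real kernel -> (17, 48)
    out = out[:, : out_shape[1]]             # drop the padded N columns     -> (17, 35)
    assert out.shape == out_shape
    assert torch.allclose(out, a @ b)        # zero padding does not change the valid region

Because the extra K entries are zeros and the extra N columns are discarded, the padded product matches the unpadded one up to floating-point rounding, which is why the padded path can be checked against an unpadded reference.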
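End to end, enabling the renamed flag from user code would look roughly like the following. This is a hedged sketch that assumes torchao's convert_to_float8_training entry point and the post-rename field name from this diff; it is not code taken from the PR.

    import torch
    import torch.nn as nn
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    # pad_dimensions (formerly pad_inner_dim) zero-pads operands to multiples of 16
    # so _scaled_mm accepts shapes like Linear(17, 35), at the cost of extra memory.
    config = Float8LinearConfig(pad_dimensions=True)

    m = nn.Sequential(nn.Linear(17, 35, device="cuda", dtype=torch.bfloat16))
    convert_to_float8_training(m, config=config)  # swaps nn.Linear for Float8Linear

As the config.py comment notes, padding trades extra memory for _scaled_mm's alignment requirement, which is why the flag stays off by default.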