diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index e3f19d8f49..6354f7a9c4 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -62,13 +62,21 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "deepgemm_128_1_128_128",
+    ), "unsupported"
 
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
     print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
+    # TODO(this PR): make gpu kernel time work with deepgemm kernel
+    print(f"use_gpu_kernel_time: {use_gpu_kernel_time}")
+    print(f"recipe: {recipe}")
 
     headers = (
         "fast_accum",
@@ -121,16 +129,31 @@ def run(
     elif recipe == "mxfp8_cublas":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe == "deepgemm_128_1_128_128":
+        scale_a = torch.ones(M, K // 128, device=device)
+        scale_b = torch.ones(N // 128, K // 128, device=device)
     else:
         assert False, f"unknown recipe {recipe}"
 
-    def do_matmul(A, B):
-        nonlocal scale_a
-        nonlocal scale_b
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+    if recipe == "deepgemm_128_1_128_128":
+        from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
+            scaled_mm_deep_gemm_128_1_128_128,
         )
 
+        def do_matmul(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return scaled_mm_deep_gemm_128_1_128_128(A, B.t(), scale_a, scale_b)
+
+    else:
+
+        def do_matmul(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
+
     fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
         tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
     )
diff --git a/test/prototype/deep_gemm_float8_training/test_base.py b/test/prototype/deep_gemm_float8_training/test_base.py
new file mode 100644
index 0000000000..0f996c99a5
--- /dev/null
+++ b/test/prototype/deep_gemm_float8_training/test_base.py
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import random
+
+import pytest
+import torch
+
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_7,
+)
+
+if not TORCH_VERSION_AT_LEAST_2_7:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
+    scale_narrow_tiles,
+    scale_square_tiles,
+    scaled_mm_deep_gemm_128_1_128_1,
+    scaled_mm_deep_gemm_128_1_128_128,
+    unscale_narrow_tiles,
+    unscale_square_tiles,
+)
+from torchao.prototype.deep_gemm_float8_training.linear import (
+    DeepGemmFloat8Linear,
+    DeepGemmFloat8LinearConfig,
+)
+from torchao.quantization import quantize_
+
+random.seed(0)
+torch.manual_seed(0)
+
+
+class TestDeepGemmUtils:
+    @pytest.mark.parametrize("mkn", [(128, 128, 128), (256, 512, 1024)])
+    def test_128_1_128_128_gemm(self, mkn):
+        M, K, N = mkn
+        tile_size = 128
+        x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+        w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
+        xq, xs = scale_narrow_tiles(x, tile_size=tile_size)
+        wq, ws = scale_square_tiles(w, tile_size=tile_size)
+        y = scaled_mm_deep_gemm_128_1_128_128(xq, wq, 1.0 / xs, 1.0 / ws)
+        y_ref = x @ w.T
+        sqnr = compute_error(y_ref, y)
+        assert sqnr > 26.0
+
+    @pytest.mark.parametrize("mkn", [(128, 128, 128), (256, 512, 1024)])
+    def test_128_1_128_1_gemm(self, mkn):
+        M, K, N = mkn
+        tile_size = 128
+        x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+        g = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
+        xq, xs = scale_narrow_tiles(x, tile_size=tile_size)
+        gq, gs = scale_narrow_tiles(g, tile_size=tile_size)
+        gi = scaled_mm_deep_gemm_128_1_128_1(xq, gq, 1.0 / xs, 1.0 / gs)
+        gi_ref = x @ g.T
+        sqnr = compute_error(gi_ref, gi)
+        assert sqnr > 27.0
+
+    def test_scale_square_tiles(self):
+        h, w = 8, 8
+        tile_size = 4
+
+        x = torch.arange(h * w, device="cuda").float().reshape(h, w)
+        xq, s = scale_square_tiles(x, tile_size=tile_size)
+        xqdq = unscale_square_tiles(xq, s, tile_size=tile_size)
+        sqnr = compute_error(x, xqdq)
+        assert sqnr >= 25.0
+
+    def test_scale_narrow_tiles(self):
+        h, w = 8, 16
+        tile_size = 4
+
+        x = torch.arange(h * w, device="cuda").float().reshape(h, w)
+        xq, s = scale_narrow_tiles(x, tile_size=tile_size)
+        xqdq = unscale_narrow_tiles(xq, s, tile_size=tile_size)
+        sqnr = compute_error(x, xqdq)
+        assert sqnr >= 32.0
+
+
+class TestDeepGemmLinear:
+    @pytest.mark.parametrize("x_rank", [2, 3])
+    def test_hello_world(self, x_rank):
+        M, K, N = 128, 256, 512
+
+        x_ref = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+        while len(x_ref.shape) < x_rank:
+            x_ref = x_ref.unsqueeze(0)
+        x_ref.requires_grad_()
+
+        m_ref = torch.nn.Linear(K, N, bias=False).bfloat16().cuda()
+        go_ref = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+        while len(go_ref.shape) < x_rank:
+            go_ref = go_ref.unsqueeze(0)
+
+        x = copy.deepcopy(x_ref).requires_grad_()
+        m = copy.deepcopy(m_ref)
+        go = copy.deepcopy(go_ref)
+
+        m = DeepGemmFloat8Linear.from_float(m)
+
+        y_ref = m_ref(x_ref)
+        y_ref.backward(go_ref)
+        y = m(x)
+        y.backward(go)
+
+        sqnr_y = compute_error(y_ref, y)
+        sqnr_gi = compute_error(x_ref.grad, x.grad)
+        sqnr_gw = compute_error(m_ref.weight.grad, m.weight.grad)
+        assert sqnr_y >= 25.0
+        assert sqnr_gi >= 25.0
+        assert sqnr_gw >= 25.0
+
+    def test_api(self):
+        m = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False))
+        quantize_(m, config=DeepGemmFloat8LinearConfig())
+        assert type(m[0]) == DeepGemmFloat8Linear
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
"__main__": + pytest.main([__file__]) diff --git a/torchao/prototype/deep_gemm_float8_training/deep_gemm_utils.py b/torchao/prototype/deep_gemm_float8_training/deep_gemm_utils.py new file mode 100644 index 0000000000..032bb33d2b --- /dev/null +++ b/torchao/prototype/deep_gemm_float8_training/deep_gemm_utils.py @@ -0,0 +1,110 @@ +# TODO gate by existence of deep_gemm library +import deep_gemm +import torch + + +def scaled_mm_deep_gemm_128_1_128_128(a, b, a_scale, b_scale): + M, K = a.shape + N, K = b.shape + out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device) + deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out=out) + return out + + +def scaled_mm_deep_gemm_128_1_128_1(a, b, a_scale, b_scale): + M, K = a.shape + N, K = b.shape + # Note: the results from `wgrad_gemm_fp8_fp8_fp32_nt` are **accumulated** + # into this tensor. For now, we initialize with `zeros` to get correct + # numerics in toy examples. For a real use case, this will need to pass + # in the gradient tensor directly. + out = torch.zeros((M, N), dtype=torch.float, device=a.device) + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((a, a_scale), (b, b_scale), out=out) + return out + + +def scale_narrow_tiles(x, tile_size=128): + """ + Input: weight tensor in high precision + Output: weight tensor in float8, and scale, tiled 1 by tile_size + + This is one function because logically this should be a fused kernel. + """ + # TODO assert row major + orig_shape = x.shape + x = x.reshape(-1, tile_size) + x_amax = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4) + # TODO read from finfo instead of hardcoding + s = 448.0 / x_amax + + x = (x * s).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn) + x = x.reshape(*orig_shape) + s = s.reshape(orig_shape[0], -1).to(torch.float) + return x, s + + +def unscale_narrow_tiles(x, s, tile_size=128): + # for debugging + orig_shape = x.shape + x = x.reshape(-1, tile_size) + s = s.reshape(-1).unsqueeze(1) + x = x.to(torch.float) / s + x = x.reshape(*orig_shape) + return x + + +def scale_square_tiles(x, tile_size=128): + """ + Input: weight tensor in high precision + Output: weight tensor in float8, and scale, tiled tile_size by tile_size + + This is one function because logically this should be a fused kernel. + `torch.compile` currently has three kernels, we should write a triton + to speed this up kernel and file an issue for compile to catch up. + """ + # TODO assert row major + assert len(x.shape) == 2, "unsupported" + height, width = x.shape + + # might be funky with dynamic shapes... + t_h = height // tile_size + t_w = width // tile_size + x = x.reshape(t_h, tile_size, t_w, tile_size) + x = x.permute(0, 2, 1, 3) + x = x.reshape(-1, tile_size * tile_size) + m = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4) + + # convert to scale + # TODO read from finfo instead of hardcoding + s = 448.0 / m + + x = (x * s).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn) + x = x.reshape(t_h, t_w, tile_size, tile_size) + x = x.permute(0, 2, 1, 3) + x = x.reshape(height, width) + s = s.reshape(t_h, t_w).to(torch.float) + + return x, s + + +def unscale_square_tiles(x, s, tile_size=128): + # for debugging + + assert len(x.shape) == 2, "unsupported" + height, width = x.shape + + # might be funky with dynamic shapes... 
+    t_h = height // tile_size
+    t_w = width // tile_size
+    x = x.reshape(t_h, tile_size, t_w, tile_size)
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(-1, tile_size * tile_size)
+
+    s = s.reshape(-1).unsqueeze(1)
+
+    x = x.float() / s
+
+    x = x.reshape(t_h, t_w, tile_size, tile_size)
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(height, width)
+    return x
diff --git a/torchao/prototype/deep_gemm_float8_training/linear.py b/torchao/prototype/deep_gemm_float8_training/linear.py
new file mode 100644
index 0000000000..f3b0ca4b67
--- /dev/null
+++ b/torchao/prototype/deep_gemm_float8_training/linear.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
+    scale_narrow_tiles,
+    scale_square_tiles,
+    scaled_mm_deep_gemm_128_1_128_1,
+    scaled_mm_deep_gemm_128_1_128_128,
+)
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+
+
+@torch._dynamo.allow_in_graph
+class deep_gemm_float8_fw_bw(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+    ):
+        ctx.save_for_backward(input_hp, weight_hp)
+        input_orig_shape = input_hp.shape
+        input_hp = input_hp.reshape(-1, input_orig_shape[-1])
+        assert input_hp.shape[-1] % 128 == 0, "unsupported"
+
+        # cast input to float8
+        input_fp8, input_scale = scale_narrow_tiles(input_hp, tile_size=128)
+
+        # cast weight to float8 and save for bw
+        weight_fp8, weight_scale = scale_square_tiles(weight_hp, tile_size=128)
+
+        # float8 gemm
+        output = scaled_mm_deep_gemm_128_1_128_128(
+            input_fp8, weight_fp8, 1.0 / input_scale, 1.0 / weight_scale
+        )
+        output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_hp, weight_hp = ctx.saved_tensors
+        weight_hp_t = weight_hp.t().contiguous()
+
+        input_orig_shape = input_hp.shape
+        input_hp = input_hp.reshape(-1, input_orig_shape[-1])
+
+        grad_output_orig_shape = grad_output.shape
+        grad_output = grad_output.reshape(-1, grad_output_orig_shape[-1])
+        assert grad_output.shape[1] % 128 == 0, "unsupported"
+
+        grad_output_fp8_dim0, grad_output_scale_dim0 = scale_narrow_tiles(
+            grad_output, tile_size=128
+        )
+        # TODO reuse from forward instead of casting again
+        weight_fp8, weight_scale = scale_square_tiles(weight_hp_t, tile_size=128)
+        grad_input = scaled_mm_deep_gemm_128_1_128_128(
+            grad_output_fp8_dim0,
+            weight_fp8,
+            1.0 / grad_output_scale_dim0,
+            1.0 / weight_scale,
+        )
+        grad_input = grad_input.reshape(
+            *grad_output_orig_shape[:-1], grad_input.shape[-1]
+        )
+
+        if False:
+            # passes unit tests, but broken in torchtitan with
+            # https://gist.github.com/vkuzo/3a763e150dbb37e5b917833a460f7f92
+            grad_output_fp8_dim1, grad_output_scale_dim1 = scale_narrow_tiles(
+                grad_output.t().contiguous(), tile_size=128
+            )
+            input_hp_fp8_dim1, input_hp_scale_dim1 = scale_narrow_tiles(
+                input_hp.t().contiguous(), tile_size=128
+            )
+            grad_weight = scaled_mm_deep_gemm_128_1_128_1(
+                grad_output_fp8_dim1,
+                input_hp_fp8_dim1,
+                1.0 / grad_output_scale_dim1,
+                1.0 / input_hp_scale_dim1,
+            )
+            grad_weight = grad_weight.to(grad_output.dtype)
+        else:
+            # workaround - leave this gemm in bf16
+            grad_weight = grad_output.t() @ input_hp
+
+        return grad_input, grad_weight
+
+
+class DeepGemmFloat8Linear(torch.nn.Linear):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = deep_gemm_float8_fw_bw.apply(
+            input,
+            self.weight,
+        )
+        # TODO add bias support
+        return output
+
+    @classmethod
+    def from_float(
+        cls,
+        mod,
+    ):
+        assert mod.bias is None, "unsupported"
+        assert mod.in_features % 128 == 0, "unsupported"
+        assert mod.out_features % 128 == 0, "unsupported"
+        with torch.device("meta"):
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+            )
+        new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
+        return new_mod
+
+
+class DeepGemmFloat8LinearConfig(AOBaseConfig):
+    pass
+
+
+@register_quantize_module_handler(DeepGemmFloat8LinearConfig)
+def _deep_gemm_float8_inference_linear_transform(module, config):
+    return DeepGemmFloat8Linear.from_float(module)
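
The two sketches below are editorial additions, not part of the patch. First, a minimal usage sketch of the new prototype API, pieced together from `test_api` and `test_hello_world` above; it assumes a CUDA device, the `deep_gemm` library installed, bfloat16 weights, `bias=False`, and in/out features that are multiples of 128.

# usage sketch (editorial, not part of the patch)
import torch

from torchao.prototype.deep_gemm_float8_training.linear import (
    DeepGemmFloat8LinearConfig,
)
from torchao.quantization import quantize_

# dims must be multiples of 128 and bias must be None (see the from_float asserts)
m = torch.nn.Sequential(torch.nn.Linear(256, 512, bias=False)).bfloat16().cuda()
quantize_(m, config=DeepGemmFloat8LinearConfig())

# one training step: forward runs the 128x1 (activation) by 128x128 (weight)
# float8 gemm; backward computes grad_input in float8 and, per the workaround
# in linear.py, grad_weight in bf16
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = m(x)
y.sum().backward()

Second, a sketch of what `scaled_mm_deep_gemm_128_1_128_128` computes, expressed with the debugging `unscale_*` helpers from `deep_gemm_utils.py`: dequantize each operand with its tile scales and run a plain matmul, under the same assumptions as above.

# numerics sketch (editorial, not part of the patch)
import torch

from torchao.float8.float8_utils import compute_error
from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
    scale_narrow_tiles,
    scale_square_tiles,
    scaled_mm_deep_gemm_128_1_128_128,
    unscale_narrow_tiles,
    unscale_square_tiles,
)

M, K, N = 256, 512, 1024
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# x is quantized per (1, 128) tile, w per (128, 128) tile; scale s = 448.0 / amax
xq, xs = scale_narrow_tiles(x, tile_size=128)
wq, ws = scale_square_tiles(w, tile_size=128)

# the deep_gemm wrappers take the reciprocal scales, i.e. the dequant multipliers
y = scaled_mm_deep_gemm_128_1_128_128(xq, wq, 1.0 / xs, 1.0 / ws)

# reference: dequantize both operands with the debugging helpers, matmul in fp32
x_dq = unscale_narrow_tiles(xq, xs, tile_size=128)
w_dq = unscale_square_tiles(wq, ws, tile_size=128)
y_ref = x_dq @ w_dq.t()

# both paths consume the same quantized values, so the SQNR should be high
print(compute_error(y_ref, y))

Shapes and dimensions beyond those exercised by the tests are illustrative only.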