From 87178e04cadebff0532293a2865f12be30f43072 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 15 May 2025 10:24:03 -0700 Subject: [PATCH 01/25] Add selective weight loading decode kernel for activation sparsity Summary: This PR adds in a kernel to accelerate $$xW^T$$ when x is sparse and we are memory bound. The idea here is that we can avoid loading the columns of $W$ that correspond to the zero elements of $x$. This lets us accelerate activation sparsity for bs=1 decode use cases. Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/kernel/sparse_gemv/sparse_gemv.py | 139 ++++++++++++++++++ torchao/kernel/sparse_gemv/test_kernel.py | 42 ++++++ .../sparsity/activation/srelu_linear.py | 49 ++++++ 3 files changed, 230 insertions(+) create mode 100644 torchao/kernel/sparse_gemv/sparse_gemv.py create mode 100644 torchao/kernel/sparse_gemv/test_kernel.py diff --git a/torchao/kernel/sparse_gemv/sparse_gemv.py b/torchao/kernel/sparse_gemv/sparse_gemv.py new file mode 100644 index 0000000000..68a3871a56 --- /dev/null +++ b/torchao/kernel/sparse_gemv/sparse_gemv.py @@ -0,0 +1,139 @@ +# adapted from deja vu + +from typing import Optional + +import torch +import triton +import triton.language as tl +def init_to_zero(*names): + def init_func(nargs): + for name in names: + nargs[name].zero_() + return init_func + +# NOTE: will need to warm up kernels each time, triton autotune caching isn't a thing right now + +configs=[ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")), + + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), + + triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), + + + # Llama 3 variants can use BLOCK_N >= 1024 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), +] + +@triton.autotune( + configs=configs, + key=["CACHE_KEY_M", "CACHE_KEY_N", "BATCHSIZE"], +) +@triton.jit +def 
splitk_sparse_gemv_kernel( + Y, # Pointers to matrices + A, X, + # Matrix dimensions + N, M, + CACHE_KEY_N, CACHE_KEY_M, + # Meta-parameters + BATCHSIZE: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr, +): + start_n = tl.program_id(0) + start_m = tl.program_id(1) + # now compute the block that each program will go through + # rn (resp. rm) denotes a range of indices for rows (resp. col) of A + + rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + A_ptr = A + (rm[:, None] * N + rn[None, :]) + X_ptr = X + rm + Y_ptr = Y + rn + + # eviction policy go brrr + if BATCHSIZE == 1: + x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last') # reuse x across threadblocks + idx = x0 != 0 + # selectively load weight rows + a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first') # only load weights once per threadblock + acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], 0) + + # rematerialize rm and rn to save registers + rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + tl.atomic_add(Y_ptr, acc0, mask=rn < N) + + +from torch.library import triton_op, wrap_triton + +# NOTE: assumes that weight is column major +@triton_op("torchao::splitk_sparse_gemv", mutates_args={}) +def splitk_sparse_gemv( + x: torch.Tensor, + weight: torch.Tensor, +) -> torch.Tensor: + """ + Compute y = sparse(X) @ weight. + :param x: input tensor [1, 1, Z] + :param weight: weight matrix [N, Z] + :return: result tensor y + """ + N, Z = weight.shape + beam_width, seq_len, _ = x.shape + assert x.shape[2] == Z + x = x.contiguous() + + assert weight.stride(1) > 1, "weight should be column major" + + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + triton.cdiv(N, META["BLOCK_N"]), + triton.cdiv(Z, META["BLOCK_M"]), + ) # noqa + + output = torch.empty( + beam_width, + seq_len, + N, + device=x.device, + dtype=torch.float16, + ) + + + kernel = wrap_triton(splitk_sparse_gemv_kernel) + kernel[grid]( + output, # data ptrs + weight, + x, + N, # shapes + Z, + N // 16, # key for triton cache (limit number of compilations) + Z // 16, + beam_width, # BATCHSIZE + # can't use kwargs because auto-tuner requires args + ) + + if x.dtype is not output.dtype: + print(f"Warning: incuring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. ") + return output.to(dtype=x.dtype) + + return output diff --git a/torchao/kernel/sparse_gemv/test_kernel.py b/torchao/kernel/sparse_gemv/test_kernel.py new file mode 100644 index 0000000000..38a1829c26 --- /dev/null +++ b/torchao/kernel/sparse_gemv/test_kernel.py @@ -0,0 +1,42 @@ +import torch +from sparse_gemv import splitk_sparse_gemv +from triton.testing import do_bench + +def create_binary_tensor(shape, percent_zeros): + """ + Creates a PyTorch tensor with a specific percentage of zeros and ones. 
+ + Args: + shape (tuple): The shape of the tensor to create + percent_zeros (float): Percentage of zeros in the tensor (between 0 and 1) + + Returns: + torch.Tensor: A tensor with specified percentage of zeros and ones + """ + # Calculate the total number of elements + total_elements = torch.prod(torch.tensor(shape)).item() + + # Calculate number of zeros needed + num_zeros = int(total_elements * percent_zeros) + + # Create a vector of all ones + tensor = torch.ones(total_elements) + + # Randomly choose indices to set to zero + zero_indices = torch.randperm(total_elements)[:num_zeros] + tensor[zero_indices] = 0 + + # Reshape to the desired shape + tensor = tensor.reshape(shape) + + return tensor + +for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: + + a = create_binary_tensor((1, 1, 4096), sparsity_level).cuda().to(torch.float16) + b = torch.randn(16384, 4096).cuda().to(torch.float16).T.contiguous().T + + sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 + dense_time = do_bench(lambda: torch.matmul(a, b.T)) * 1e6 + speedup = dense_time / sparse_time + print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index f8c3288b67..51846c9933 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -13,6 +13,7 @@ from torchao.quantization.transform_module import ( register_quantize_module_handler, ) +from torchao.kernel.sparse_gemv.sparse_gemv import splitk_sparse_gemv @dataclass @@ -85,3 +86,51 @@ def from_dense( raise NotImplementedError("weight dtype must be bf16") return cls(linear.weight.data, config) + + +@dataclass +class ActivationSparseLinearConfig(AOBaseConfig): + """ + Adds in acceleration for activation sparsity to linear layers for decode. + + Args: + `activation_dtype`: data type for quantized activation tensor. + `weight_dtype`: data type for quantized weight tensor. 
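+
+    Note: applying this config swaps the module for `ActivationSparseLinear` (defined
+    below), which routes single-token decode inputs (`x.shape[1] == 1`) through
+    `torch.ops.torchao.splitk_sparse_gemv` and falls back to `F.linear` otherwise.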
+ """ + + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + +@register_quantize_module_handler( + ActivationSparseLinearConfig +) +def _activation_spare_linear_transform( + module: torch.nn.Module, + config: ActivationSparseLinearConfig, +): + return ActivationSparseLinear.from_dense(module, config) + + +class ActivationSparseLinear(nn.Module): + """ + Replacement nn.Linear that supports runtime fp8 activation sparsity + """ + + def __init__(self, weight, config) -> None: + super().__init__() + self.config = config + self.weight_transposed = weight.T.contiguous().T + + def forward(self, x): + if x.shape[1] == 1: + return torch.ops.torchao.splitk_sparse_gemv(x, self.weight_transposed) + else: + return torch.nn.functional.linear(x, self.weight_transposed) + + @classmethod + def from_dense( + cls, linear, config:ActivationSparseLinearConfig + ): + if linear.bias is not None: + raise NotImplementedError("bias is not supported") + return cls(linear.weight.data, config) From 17c953130ac5aee9c161087a992a06ff271a612f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 15 May 2025 14:40:19 -0700 Subject: [PATCH 02/25] update --- benchmarks/benchmark_splitk_sparse_gemv.py | 16 ++++ test/sparsity/test_activation24.py | 18 ++++- torchao/kernel/sparse_gemv/test_kernel.py | 42 ---------- .../sparse_gemv.py => splitk_sparse_gemv.py} | 80 ++++++++++--------- torchao/sparsity/utils.py | 21 +++++ 5 files changed, 95 insertions(+), 82 deletions(-) create mode 100644 benchmarks/benchmark_splitk_sparse_gemv.py delete mode 100644 torchao/kernel/sparse_gemv/test_kernel.py rename torchao/kernel/{sparse_gemv/sparse_gemv.py => splitk_sparse_gemv.py} (76%) diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py new file mode 100644 index 0000000000..27a08df987 --- /dev/null +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -0,0 +1,16 @@ +import torch + +from triton.testing import do_bench +from torchao.sparsity.utils import create_binary_tensor +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv + + +for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: + + a = create_binary_tensor((1, 1, 4096), sparsity_level).cuda().to(torch.float16) + b = torch.randn(16384, 4096).cuda().to(torch.float16).T.contiguous().T + + sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 + dense_time = do_bench(lambda: torch.matmul(a, b.T)) * 1e6 + speedup = dense_time / sparse_time + print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 65b7cfd8d2..1ab92a1aa1 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -18,7 +18,7 @@ SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, ) from torchao.sparsity import sparsify_ -from torchao.sparsity.utils import create_semi_structured_tensor +from torchao.sparsity.utils import create_semi_structured_tensor, create_binary_tensor from torchao.utils import is_sm_at_least_90 @@ -141,3 +141,19 @@ def srelu_linear(x): custom_output = reference_linear_copy(input_tensor) torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) + + + +def test_splitk_sparse_gemv(): + torch.manual_seed(0) + + activation = create_binary_tensor((1, 1, 1024), 0.5).cuda().to(torch.float16) + weight = torch.randn(1024, 1024, dtype=torch.float16).cuda() + + 
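+    # NOTE: splitk_sparse_gemv_kernel reads the weight as W^T (it indexes A + rm * N + rn)
+    # and the wrapper asserts weight.stride(1) > 1, hence the re-striding below.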
# weight must be column major + weight_transposed = weight.T.contiguous().T + + sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) + dense_res = F.linear(activation, weight_transposed) + + torch.testing.assert_close(sparse_res, dense_res, rtol=0.1, atol=0.01) diff --git a/torchao/kernel/sparse_gemv/test_kernel.py b/torchao/kernel/sparse_gemv/test_kernel.py deleted file mode 100644 index 38a1829c26..0000000000 --- a/torchao/kernel/sparse_gemv/test_kernel.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from sparse_gemv import splitk_sparse_gemv -from triton.testing import do_bench - -def create_binary_tensor(shape, percent_zeros): - """ - Creates a PyTorch tensor with a specific percentage of zeros and ones. - - Args: - shape (tuple): The shape of the tensor to create - percent_zeros (float): Percentage of zeros in the tensor (between 0 and 1) - - Returns: - torch.Tensor: A tensor with specified percentage of zeros and ones - """ - # Calculate the total number of elements - total_elements = torch.prod(torch.tensor(shape)).item() - - # Calculate number of zeros needed - num_zeros = int(total_elements * percent_zeros) - - # Create a vector of all ones - tensor = torch.ones(total_elements) - - # Randomly choose indices to set to zero - zero_indices = torch.randperm(total_elements)[:num_zeros] - tensor[zero_indices] = 0 - - # Reshape to the desired shape - tensor = tensor.reshape(shape) - - return tensor - -for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: - - a = create_binary_tensor((1, 1, 4096), sparsity_level).cuda().to(torch.float16) - b = torch.randn(16384, 4096).cuda().to(torch.float16).T.contiguous().T - - sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 - dense_time = do_bench(lambda: torch.matmul(a, b.T)) * 1e6 - speedup = dense_time / sparse_time - print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") diff --git a/torchao/kernel/sparse_gemv/sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py similarity index 76% rename from torchao/kernel/sparse_gemv/sparse_gemv.py rename to torchao/kernel/splitk_sparse_gemv.py index 68a3871a56..897c92ae83 100644 --- a/torchao/kernel/sparse_gemv/sparse_gemv.py +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -1,45 +1,48 @@ -# adapted from deja vu +""" +This code is adapted from https://github.com/FasterDecoding/TEAL/blob/main/kernels/sparse_gemv.py + +Since we already have sparse activations from ReLU, we can get rid of the thresholding step and just use the sparse tensor directly. 
+""" from typing import Optional import torch import triton import triton.language as tl -def init_to_zero(*names): - def init_func(nargs): - for name in names: - nargs[name].zero_() - return init_func +from torch.library import triton_op, wrap_triton + +def init_to_zero(*args, **kwargs): + # print(type) + args[0]["Y"].zero_() # NOTE: will need to warm up kernels each time, triton autotune caching isn't a thing right now configs=[ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")), - - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")), - - triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")), - - - # Llama 3 variants can use BLOCK_N >= 1024 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero), + + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), + + triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), + 
triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), + + + # # Llama 3 variants can use BLOCK_N >= 1024 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), ] @triton.autotune( @@ -71,19 +74,18 @@ def splitk_sparse_gemv_kernel( # eviction policy go brrr if BATCHSIZE == 1: - x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last') # reuse x across threadblocks - idx = x0 != 0 + x0 = tl.load(X_ptr, mask=rm < M, other=0, eviction_policy='evict_last') # reuse x across threadblocks + idx = (x0 != 0) # selectively load weight rows - a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first') # only load weights once per threadblock + a = tl.load(A_ptr, mask=idx[:, None], other=0, eviction_policy='evict_first') # only load weights once per threadblock acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], 0) # rematerialize rm and rn to save registers rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - tl.atomic_add(Y_ptr, acc0, mask=rn < N) + tl.atomic_add(Y_ptr, acc0, mask=rn < N, sem="relaxed") -from torch.library import triton_op, wrap_triton # NOTE: assumes that weight is column major @triton_op("torchao::splitk_sparse_gemv", mutates_args={}) @@ -100,7 +102,7 @@ def splitk_sparse_gemv( N, Z = weight.shape beam_width, seq_len, _ = x.shape assert x.shape[2] == Z - x = x.contiguous() + assert x.is_contiguous() assert weight.stride(1) > 1, "weight should be column major" @@ -108,7 +110,7 @@ def splitk_sparse_gemv( grid = lambda META: ( triton.cdiv(N, META["BLOCK_N"]), triton.cdiv(Z, META["BLOCK_M"]), - ) # noqa + ) output = torch.empty( beam_width, diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 24c0808a02..06656e032f 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -47,6 +47,27 @@ def create_semi_structured_tensor(r, c, dtype): return sparse_weight.to(dtype) +def create_binary_tensor(shape, percent_zeros): + """ + Creates a PyTorch tensor with a specific percentage of zeros and ones. 
+ + Args: + shape (tuple): The shape of the tensor to create + percent_zeros (float): Percentage of zeros in the tensor (between 0 and 1) + + Returns: + torch.Tensor: A tensor with specified percentage of zeros and ones + """ + total_elements = torch.prod(torch.tensor(shape)).item() + num_zeros = int(total_elements * percent_zeros) + tensor = torch.ones(total_elements) + zero_indices = torch.randperm(total_elements)[:num_zeros] + tensor[zero_indices] = 0 + tensor = tensor.reshape(shape) + + return tensor + + # Observers class PerChannelNormObserver(UniformQuantizationObserverBase): """ From a5ec96e31230ddb1bebbe17e931683ebb4b80512 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 19 May 2025 07:28:06 -0700 Subject: [PATCH 03/25] cleanup --- benchmarks/benchmark_splitk_sparse_gemv.py | 3 ++- test/sparsity/test_activation24.py | 10 +++++++--- torchao/kernel/splitk_sparse_gemv.py | 13 +++++-------- .../prototype/sparsity/activation/srelu_linear.py | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py index 27a08df987..bc2a0dc6aa 100644 --- a/benchmarks/benchmark_splitk_sparse_gemv.py +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -3,6 +3,7 @@ from triton.testing import do_bench from torchao.sparsity.utils import create_binary_tensor from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv +import torch.nn.functional as F for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: @@ -11,6 +12,6 @@ b = torch.randn(16384, 4096).cuda().to(torch.float16).T.contiguous().T sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 - dense_time = do_bench(lambda: torch.matmul(a, b.T)) * 1e6 + dense_time = do_bench(lambda: F.linear(a, b)) * 1e6 speedup = dense_time / sparse_time print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 1ab92a1aa1..96ac8d1d2f 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -146,9 +146,12 @@ def srelu_linear(x): def test_splitk_sparse_gemv(): torch.manual_seed(0) + # print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction) + # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction) - activation = create_binary_tensor((1, 1, 1024), 0.5).cuda().to(torch.float16) - weight = torch.randn(1024, 1024, dtype=torch.float16).cuda() + activation = create_binary_tensor((1, 1, 4096), 0.2).cuda().to(torch.float16) + weight = torch.randn(16384, 4096, dtype=torch.float16).cuda() # weight must be column major weight_transposed = weight.T.contiguous().T @@ -156,4 +159,5 @@ def test_splitk_sparse_gemv(): sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) dense_res = F.linear(activation, weight_transposed) - torch.testing.assert_close(sparse_res, dense_res, rtol=0.1, atol=0.01) + # This rtol is very high, due to + torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) diff --git a/torchao/kernel/splitk_sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py index 897c92ae83..c311753385 100644 --- a/torchao/kernel/splitk_sparse_gemv.py +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -19,7 +19,6 @@ def init_to_zero(*args, **kwargs): configs=[ triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, 
pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), @@ -31,13 +30,11 @@ def init_to_zero(*args, **kwargs): triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), - # # Llama 3 variants can use BLOCK_N >= 1024 triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), @@ -74,16 +71,16 @@ def splitk_sparse_gemv_kernel( # eviction policy go brrr if BATCHSIZE == 1: - x0 = tl.load(X_ptr, mask=rm < M, other=0, eviction_policy='evict_last') # reuse x across threadblocks + x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last') # reuse x across threadblocks idx = (x0 != 0) # selectively load weight rows - a = tl.load(A_ptr, mask=idx[:, None], other=0, eviction_policy='evict_first') # only load weights once per threadblock - acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], 0) + a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first') # only load weights once per threadblock + acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], axis=0) # rematerialize rm and rn to save registers - rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + # rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - tl.atomic_add(Y_ptr, acc0, mask=rn < N, sem="relaxed") + tl.atomic_add(Y_ptr, acc0, mask=rn < N) diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index 51846c9933..4fdb6a6919 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -13,7 +13,7 @@ from torchao.quantization.transform_module import ( register_quantize_module_handler, ) -from torchao.kernel.sparse_gemv.sparse_gemv import splitk_sparse_gemv +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv @dataclass From 8bc4dc4eaebb67a74ced5502a43a6065306e754c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 19 May 2025 09:45:34 -0700 Subject: [PATCH 04/25] ruff format --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 84 ++++++++----------- benchmarks/benchmark_splitk_sparse_gemv.py | 21 +++-- e2e_fp8_sparse.csv | 8 -- test/sparsity/test_activation24.py | 40 +++++++-- torchao/kernel/splitk_sparse_gemv.py | 65 +++++++------- .../sparsity/activation/srelu_linear.py | 1 - .../prototype/sparsity/activation/utils.py | 4 +- 7 files changed, 122 insertions(+), 101 deletions(-) delete mode 100644 e2e_fp8_sparse.csv diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index fbab8c0671..a37c370d80 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -3,6 +3,8 @@ # # This source code is 
licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import copy + import pandas as pd import torch from torch import nn @@ -10,7 +12,7 @@ from triton.testing import do_bench from torchao.prototype.sparsity.activation.srelu_linear import ( - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, + ActivationSparseLinearConfig, ) from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( @@ -20,16 +22,23 @@ PerRow, quantize_, ) +from torchao.sparsity.utils import create_binary_tensor def benchmark_microseconds(f, *args): return do_bench(lambda: f(*args), return_mode="median") * 1e3 -def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): +def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): + + target_sparsity_output = create_binary_tensor((1, num_tokens, intermediate_size), 0.9).cuda().to(torch.bfloat16) + target_sparsity_output = torch.randn(num_tokens, intermediate_size).cuda().to(torch.bfloat16) * target_sparsity_output + print(target_sparsity_output) + + ffn_ref = ( nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), + # nn.Linear(hidden_size, intermediate_size, bias=False), SquaredReLU(), nn.Linear(intermediate_size, hidden_size, bias=False), ) @@ -37,32 +46,22 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): .cuda() ) - input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda() + + # input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda() + input_tensor = target_sparsity_output + # ffn_ref[0].weight.data = torch.linalg.solve(input_tensor, target_sparsity_output).T + # ffn_ref[0].weight.data = torch.load("/data/users/jessecai/ao/checkpoints/meta-llama/ffn_up.pt") + fp16_time = benchmark_microseconds(ffn_ref, input_tensor) + # breakpoint() # bf16 - ffn_clone = ( - nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), - SquaredReLU(), - nn.Linear(intermediate_size, hidden_size, bias=False), - ) - .to(torch.bfloat16) - .cuda() - ) + ffn_clone = copy.deepcopy(ffn_ref) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp16_c_time = benchmark_microseconds(ffn_clone, input_tensor) # fp8 - ffn_clone = ( - nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), - SquaredReLU(), - nn.Linear(intermediate_size, hidden_size, bias=False), - ) - .to(torch.bfloat16) - .cuda() - ) + ffn_clone = copy.deepcopy(ffn_ref) quantize_( ffn_clone, Float8DynamicActivationFloat8WeightConfig( @@ -73,38 +72,27 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor) # fp8 sparse - ffn_clone = ( - nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), - SquaredReLU(), - nn.Linear(intermediate_size, hidden_size, bias=False), - ) - .to(torch.bfloat16) - .cuda() - ) + ffn_clone = copy.deepcopy(ffn_ref) quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig()) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) # activation fp8 sparse - ffn_clone = ( - nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), - # no Squared RELU since it will be fused into the second linear - nn.Linear(intermediate_size, hidden_size, bias=False), - ) - .to(torch.bfloat16) - .cuda() - ) - quantize_( - ffn_clone[0], - Float8DynamicActivationFloat8WeightConfig( - 
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) - ), - ) + ffn_clone = copy.deepcopy(ffn_ref) + # quantize_( + # ffn_clone[0], + # Float8DynamicActivationFloat8WeightConfig( + # granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + # ), + # ) + # quantize_( + # ffn_clone, + # SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), + # filter_fn=lambda mod, fqn: "1" in fqn, + # ) quantize_( ffn_clone, - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), + ActivationSparseLinearConfig(), filter_fn=lambda mod, fqn: "1" in fqn, ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) @@ -124,7 +112,7 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): if __name__ == "__main__": with torch.no_grad(): results = [] - for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): + for num_tokens in tqdm([1]): results.append(benchmark(num_tokens)) torch.compiler.reset() diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py index bc2a0dc6aa..c5f8117f8f 100644 --- a/benchmarks/benchmark_splitk_sparse_gemv.py +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -1,17 +1,24 @@ import torch - from triton.testing import do_bench -from torchao.sparsity.utils import create_binary_tensor + from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv -import torch.nn.functional as F +from torchao.sparsity.utils import create_binary_tensor + +dtype = torch.float8_e4m3fn for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: - a = create_binary_tensor((1, 1, 4096), sparsity_level).cuda().to(torch.float16) - b = torch.randn(16384, 4096).cuda().to(torch.float16).T.contiguous().T + a = create_binary_tensor((1, 1, 4096), sparsity_level).cuda().to(dtype) + b = torch.randn(16384, 4096).cuda().to(dtype).T.contiguous().T + + sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b, out_dtype=torch.bfloat16)) * 1e6 - sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 - dense_time = do_bench(lambda: F.linear(a, b)) * 1e6 + # dense_time = do_bench(lambda: F.linear(a, b)) * 1e6 + b = torch.randn(4096, 16384).cuda().to(dtype).T.contiguous().T + dense_time = do_bench(lambda: torch._scaled_mm(a.squeeze(0), b, + scale_a=torch.Tensor([1]).cuda(), + scale_b=torch.Tensor([1]).cuda(), + out_dtype=torch.bfloat16)) * 1e6 speedup = dense_time / sparse_time print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") diff --git a/e2e_fp8_sparse.csv b/e2e_fp8_sparse.csv deleted file mode 100644 index 05a80e13b7..0000000000 --- a/e2e_fp8_sparse.csv +++ /dev/null @@ -1,8 +0,0 @@ -num_tokens,bf16_latency (us),bf16_c_latency (us),fp8_c_time (us),fp8_c_sparse_time (us),fp8_c_activation_sparse_time (us),speedup -64,166.81599617004395,163.03999722003937,103.00800204277039,74.30399954319,102.81600058078766,1.0018674278409796 -128,156.25600516796112,151.5199989080429,99.93600100278854,75.45600086450577,102.04800218343735,0.9793038458817415 -256,172.28800058364868,159.58400070667267,114.07999694347382,82.43200182914734,111.07199639081955,1.0270815385551393 -512,218.87999773025513,204.6079933643341,144.0960019826889,114.56000059843063,139.48799669742584,1.0330351384661336 -1024,394.4000005722046,392.5440013408661,251.10399723052979,196.4160054922104,227.90400683879852,1.1017972027501084 -2048,764.6080255508423,734.8160147666931,480.70400953292847,381.1520040035248,426.68798565864563,1.1265937305239622 
-4096,1658.8159799575806,1623.5840320587158,901.3440012931824,779.0079712867737,843.392014503479,1.0687129896811043 diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 96ac8d1d2f..1fe3c2b3d2 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -18,7 +18,7 @@ SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, ) from torchao.sparsity import sparsify_ -from torchao.sparsity.utils import create_semi_structured_tensor, create_binary_tensor +from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor from torchao.utils import is_sm_at_least_90 @@ -146,9 +146,6 @@ def srelu_linear(x): def test_splitk_sparse_gemv(): torch.manual_seed(0) - # print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction) - # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False - print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction) activation = create_binary_tensor((1, 1, 4096), 0.2).cuda().to(torch.float16) weight = torch.randn(16384, 4096, dtype=torch.float16).cuda() @@ -159,5 +156,38 @@ def test_splitk_sparse_gemv(): sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) dense_res = F.linear(activation, weight_transposed) - # This rtol is very high, due to + # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) + + +# baseline feather est. speed input: 26053.69 toks/s, output: 3087.22 toks/s + + +## PHI int4 +# Avg latency: 0.9666175044667095 seconds +# 10% percentile latency: 0.9595451743996819 seconds +# 25% percentile latency: 0.9622359074999167 seconds +# 50% percentile latency: 0.9658922594994692 seconds +# 75% percentile latency: 0.9698955072499302 seconds +# 90% percentile latency: 0.973863469400294 seconds +# 99% percentile latency: 0.9819546566300688 seconds + +## PHI baseline +# Avg latency: 1.5343882635333708 seconds +# 10% percentile latency: 1.521387348999906 seconds +# 25% percentile latency: 1.5298032142500233 seconds +# 50% percentile latency: 1.536684918499759 seconds +# 75% percentile latency: 1.5403528439999263 seconds +# 90% percentile latency: 1.5427267418004704 seconds +# 99% percentile latency: 1.5486474337500475 seconds + +# Feather baseline: +# Avg latency: 2.43545591363345 seconds +# 10% percentile latency: 2.419638815799408 seconds +# 25% percentile latency: 2.423865938751078 seconds +# 50% percentile latency: 2.435959159500271 seconds +# 75% percentile latency: 2.447552447250473 seconds +# 90% percentile latency: 2.4524878560003343 seconds +# 99% percentile latency: 2.4566773237699637 seconds + +# Feather fp8: diff --git a/torchao/kernel/splitk_sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py index c311753385..3e7f36dbd8 100644 --- a/torchao/kernel/splitk_sparse_gemv.py +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -3,7 +3,8 @@ Since we already have sparse activations from ReLU, we can get rid of the thresholding step and just use the sparse tensor directly. 
""" - +import sys +import warnings from typing import Optional import torch @@ -11,40 +12,39 @@ import triton.language as tl from torch.library import triton_op, wrap_triton -def init_to_zero(*args, **kwargs): - # print(type) - args[0]["Y"].zero_() - -# NOTE: will need to warm up kernels each time, triton autotune caching isn't a thing right now +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. + warnings.simplefilter("once") configs=[ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4), + triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4), # # Llama 3 variants can use BLOCK_N >= 1024 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4), ] @triton.autotune( 
configs=configs, key=["CACHE_KEY_M", "CACHE_KEY_N", "BATCHSIZE"], + reset_to_zero=["Y"], # reset the content of Y to zero before computation ) @triton.jit def splitk_sparse_gemv_kernel( @@ -78,8 +78,7 @@ def splitk_sparse_gemv_kernel( acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], axis=0) # rematerialize rm and rn to save registers - # rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - + rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) tl.atomic_add(Y_ptr, acc0, mask=rn < N) @@ -89,6 +88,7 @@ def splitk_sparse_gemv_kernel( def splitk_sparse_gemv( x: torch.Tensor, weight: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: """ Compute y = sparse(X) @ weight. @@ -131,8 +131,11 @@ def splitk_sparse_gemv( # can't use kwargs because auto-tuner requires args ) - if x.dtype is not output.dtype: - print(f"Warning: incuring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. ") - return output.to(dtype=x.dtype) + # if x.dtype is not output.dtype: + # warnings.warn(f"Warning: incuring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. ") + # return output.to(dtype=x.dtype) + + if out_dtype: + return output.to(dtype=out_dtype) return output diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index 4fdb6a6919..318168edf7 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -13,7 +13,6 @@ from torchao.quantization.transform_module import ( register_quantize_module_handler, ) -from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv @dataclass diff --git a/torchao/prototype/sparsity/activation/utils.py b/torchao/prototype/sparsity/activation/utils.py index 696649b18c..6fef483ea0 100644 --- a/torchao/prototype/sparsity/activation/utils.py +++ b/torchao/prototype/sparsity/activation/utils.py @@ -69,7 +69,9 @@ def __init__(self): super().__init__() def forward(self, x): - return F.relu(x) ** 2 + res = F.relu(x) ** 2 + # print((res==0).sum() / res.numel()) + return res def profiler_runner(path, fn, *args, **kwargs): From d49492e42b757a1bd51f7c7a6a678b70c6bba23b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 19 May 2025 10:00:52 -0700 Subject: [PATCH 05/25] cleanup --- test/sparsity/test_activation24.py | 47 ++++++++++-------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 1fe3c2b3d2..9c7673d344 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -160,34 +160,19 @@ def test_splitk_sparse_gemv(): torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) -# baseline feather est. 
speed input: 26053.69 toks/s, output: 3087.22 toks/s - - -## PHI int4 -# Avg latency: 0.9666175044667095 seconds -# 10% percentile latency: 0.9595451743996819 seconds -# 25% percentile latency: 0.9622359074999167 seconds -# 50% percentile latency: 0.9658922594994692 seconds -# 75% percentile latency: 0.9698955072499302 seconds -# 90% percentile latency: 0.973863469400294 seconds -# 99% percentile latency: 0.9819546566300688 seconds - -## PHI baseline -# Avg latency: 1.5343882635333708 seconds -# 10% percentile latency: 1.521387348999906 seconds -# 25% percentile latency: 1.5298032142500233 seconds -# 50% percentile latency: 1.536684918499759 seconds -# 75% percentile latency: 1.5403528439999263 seconds -# 90% percentile latency: 1.5427267418004704 seconds -# 99% percentile latency: 1.5486474337500475 seconds - -# Feather baseline: -# Avg latency: 2.43545591363345 seconds -# 10% percentile latency: 2.419638815799408 seconds -# 25% percentile latency: 2.423865938751078 seconds -# 50% percentile latency: 2.435959159500271 seconds -# 75% percentile latency: 2.447552447250473 seconds -# 90% percentile latency: 2.4524878560003343 seconds -# 99% percentile latency: 2.4566773237699637 seconds - -# Feather fp8: +def test_splitk_sparse_gemv_fp8(): + + torch.nn.Linear() + torch.manual_seed(0) + + activation = create_binary_tensor((1, 1, 4096), 0.2).cuda().to(torch.float16) + weight = torch.randn(16384, 4096, dtype=torch.float16).cuda() + + # weight must be column major + weight_transposed = weight.T.contiguous().T + + sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) + dense_res = F.linear(activation, weight_transposed) + + # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. 
+ torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) From cb8c629d294bcf5e50c2ffbf761a8f151b0ad9ba Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 21 May 2025 17:35:27 -0700 Subject: [PATCH 06/25] squarerelu working --- benchmarks/benchmark_splitk_sparse_gemv.py | 16 +- torchao/kernel/splitk_sparse_gemv.py | 33 ++-- .../sparsity/activation/srelu_linear.py | 48 ------ torchao/quantization/quant_api.py | 4 + torchao/sparsity/sparse_api.py | 154 ++++++++++++++++++ 5 files changed, 182 insertions(+), 73 deletions(-) diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py index c5f8117f8f..61ccb022d1 100644 --- a/benchmarks/benchmark_splitk_sparse_gemv.py +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -4,21 +4,23 @@ from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.sparsity.utils import create_binary_tensor +import torch.nn.functional as F + dtype = torch.float8_e4m3fn for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: - a = create_binary_tensor((1, 1, 4096), sparsity_level).cuda().to(dtype) + a = create_binary_tensor((1, 4096), sparsity_level).cuda().to(dtype) b = torch.randn(16384, 4096).cuda().to(dtype).T.contiguous().T sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b, out_dtype=torch.bfloat16)) * 1e6 - # dense_time = do_bench(lambda: F.linear(a, b)) * 1e6 - b = torch.randn(4096, 16384).cuda().to(dtype).T.contiguous().T - dense_time = do_bench(lambda: torch._scaled_mm(a.squeeze(0), b, - scale_a=torch.Tensor([1]).cuda(), - scale_b=torch.Tensor([1]).cuda(), - out_dtype=torch.bfloat16)) * 1e6 + dense_time = do_bench(lambda: F.linear(a.to(torch.float16), b.to(torch.float16))) * 1e6 + # b = torch.randn(4096, 16384).cuda().to(dtype).T.contiguous().T + # dense_time = do_bench(lambda: torch._scaled_mm(a.squeeze(0), b, + # scale_a=torch.Tensor([1]).cuda(), + # scale_b=torch.Tensor([1]).cuda(), + # out_dtype=torch.bfloat16)) * 1e6 speedup = dense_time / sparse_time print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") diff --git a/torchao/kernel/splitk_sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py index 3e7f36dbd8..65f6c645e5 100644 --- a/torchao/kernel/splitk_sparse_gemv.py +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -43,7 +43,7 @@ @triton.autotune( configs=configs, - key=["CACHE_KEY_M", "CACHE_KEY_N", "BATCHSIZE"], + key=["CACHE_KEY_M", "CACHE_KEY_N"], reset_to_zero=["Y"], # reset the content of Y to zero before computation ) @triton.jit @@ -54,7 +54,6 @@ def splitk_sparse_gemv_kernel( N, M, CACHE_KEY_N, CACHE_KEY_M, # Meta-parameters - BATCHSIZE: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr, ): start_n = tl.program_id(0) @@ -70,15 +69,15 @@ def splitk_sparse_gemv_kernel( Y_ptr = Y + rn # eviction policy go brrr - if BATCHSIZE == 1: - x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last') # reuse x across threadblocks - idx = (x0 != 0) - # selectively load weight rows - a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first') # only load weights once per threadblock - acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], axis=0) + x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last') # reuse x across threadblocks + idx = (x0 != 0.0) + # selectively load weight rows + a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first') # only load weights once per threadblock + acc0 = 
tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], axis=0) # rematerialize rm and rn to save registers rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + # TODO atomic add supports bfloat16 in latest triton, we should update to that tl.atomic_add(Y_ptr, acc0, mask=rn < N) @@ -97,8 +96,8 @@ def splitk_sparse_gemv( :return: result tensor y """ N, Z = weight.shape - beam_width, seq_len, _ = x.shape - assert x.shape[2] == Z + seq_len, _ = x.shape + assert x.shape[-1] == Z assert x.is_contiguous() assert weight.stride(1) > 1, "weight should be column major" @@ -109,8 +108,7 @@ def splitk_sparse_gemv( triton.cdiv(Z, META["BLOCK_M"]), ) - output = torch.empty( - beam_width, + output = torch.zeros( seq_len, N, device=x.device, @@ -127,15 +125,14 @@ def splitk_sparse_gemv( Z, N // 16, # key for triton cache (limit number of compilations) Z // 16, - beam_width, # BATCHSIZE # can't use kwargs because auto-tuner requires args ) - # if x.dtype is not output.dtype: - # warnings.warn(f"Warning: incuring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. ") - # return output.to(dtype=x.dtype) + if x.dtype is not output.dtype: + # warnings.warn(f"Warning: incuring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. ") + return output.to(dtype=x.dtype) - if out_dtype: - return output.to(dtype=out_dtype) + # if out_dtype: + # return output.to(dtype=out_dtype) return output diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index 318168edf7..f8c3288b67 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -85,51 +85,3 @@ def from_dense( raise NotImplementedError("weight dtype must be bf16") return cls(linear.weight.data, config) - - -@dataclass -class ActivationSparseLinearConfig(AOBaseConfig): - """ - Adds in acceleration for activation sparsity to linear layers for decode. - - Args: - `activation_dtype`: data type for quantized activation tensor. - `weight_dtype`: data type for quantized weight tensor. 
- """ - - activation_dtype: torch.dtype = torch.float8_e4m3fn - weight_dtype: torch.dtype = torch.float8_e4m3fn - -@register_quantize_module_handler( - ActivationSparseLinearConfig -) -def _activation_spare_linear_transform( - module: torch.nn.Module, - config: ActivationSparseLinearConfig, -): - return ActivationSparseLinear.from_dense(module, config) - - -class ActivationSparseLinear(nn.Module): - """ - Replacement nn.Linear that supports runtime fp8 activation sparsity - """ - - def __init__(self, weight, config) -> None: - super().__init__() - self.config = config - self.weight_transposed = weight.T.contiguous().T - - def forward(self, x): - if x.shape[1] == 1: - return torch.ops.torchao.splitk_sparse_gemv(x, self.weight_transposed) - else: - return torch.nn.functional.linear(x, self.weight_transposed) - - @classmethod - def from_dense( - cls, linear, config:ActivationSparseLinearConfig - ): - if linear.bias is not None: - raise NotImplementedError("bias is not supported") - return cls(linear.weight.data, config) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 982b8cdd5c..7f7367f9fb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -69,6 +69,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + TorchAOBaseTensor, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -530,6 +531,9 @@ def _quantization_type(weight: torch.Tensor): if type(weight) is torch.Tensor: return "not quantized" + if isinstance(weight, TorchAOBaseTensor): + return f"{weight.__class__.__name__}" + return "not recognized" diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index b263b5e098..1151ed437a 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -24,6 +24,24 @@ register_quantize_module_handler, ) from torchao.sparsity.blocksparse import BlockSparseTensor +from dataclasses import dataclass + +import torch +from torch import nn + +from torchao.core.config import AOBaseConfig +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, +) +from torchao.quantization.quant_api import ( + _float8_cutlass_quant, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) + +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv +from torch.utils._python_dispatch import return_and_correct_aliasing # Sparsity helper functions @@ -134,3 +152,139 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: _is_linear if filter_fn is None else filter_fn, extra_args=(config,), ) + + +from torchao.utils import TorchAOBaseTensor + +class ActivationSparseTensor(TorchAOBaseTensor): + data: Optional[torch.Tensor] + + __slots__ = ["data"] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + data: Optional[torch.Tensor], + requires_grad: bool = False, + ): + assert data is not None + kwargs = { + "device": data.device, + "dtype": data.dtype, + "layout": data.layout, + "requires_grad": requires_grad, + } + tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + tensor.data = data + return tensor + + def __repr__(self) -> str: # type: ignore[override] + assert hasattr(self, "shape") + return f"{self.__class__.__name__}(shape={self.shape})" + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self.requires_grad) + return inner_tensors, tensor_meta 
+ + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta, + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, requires_grad = tensor_meta + return cls( + shape=shape, + data=inner_tensors.get("data", None), + requires_grad=requires_grad, + ) + + @classmethod + def from_dense(cls, weight): + return cls(weight.shape, + weight.data.t().contiguous().t(), + requires_grad=False) + + def apply_fn_to_shard(self, func): + return ActivationSparseTensor( + shape=self.shape, + data=func(self.data), + requires_grad=self.requires_grad, + ) + +# Subclass op dispatch registration +implements = ActivationSparseTensor.implements +aten = torch.ops.aten + + +@implements( + [ + aten.detach.default, + aten.slice.Tensor, + ] +) +def _(func, types, args, kwargs): + new_data = func(args[0].data, *args[1:], **kwargs) + return ActivationSparseTensor( + new_data.shape, + data=new_data, + requires_grad=False, + ) + +@implements( + [aten.copy_.default] +) +def _(func, types, args, kwargs): + self = args[0] + src = args[1] + self.data.copy_(src.data) + return + +@implements(torch.nn.functional.linear) +def sparse_activation_linear(func, types, args, kwargs): + x_orig, w, bias = args + assert bias is None + x = x_orig.view(-1, x_orig.size(-1)) + # M = w.shape[0] + # K = w.shape[1] + + if x.shape[0] == 1: + x_relu = torch.square(torch.nn.functional.relu(x)) + res = torch.ops.torchao.splitk_sparse_gemv(x_relu, + w.data) + return res.view(*x_orig.shape[:-1], w.shape[0]) + else: + x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) + return torch.nn.functional.linear(x_orig_relu, w.data, bias) + + +@dataclass +class ActivationSparseLinearConfig(AOBaseConfig): + """ + Adds in acceleration for activation sparsity to linear layers for decode. + + Args: + `activation_dtype`: data type for quantized activation tensor. + `weight_dtype`: data type for quantized weight tensor. 
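+
+    Note: the dtype fields are not used yet; the transform below just wraps the weight in
+    `ActivationSparseTensor`, whose linear dispatch applies squared-ReLU to the activation
+    and then either the split-K sparse GEMV (single token) or a dense `F.linear`.
+
+    Example (as in benchmarks/benchmark_e2e_fp8_sparse_linear.py from this PR):
+
+        quantize_(ffn, ActivationSparseLinearConfig(), filter_fn=lambda mod, fqn: "1" in fqn)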
+ """ + + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + +@register_quantize_module_handler( + ActivationSparseLinearConfig +) +def _( + module: torch.nn.Module, + config: ActivationSparseLinearConfig, +): + new_weight = ActivationSparseTensor.from_dense(module.weight.data) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module From e6860c61b30bc6d130aad44dd6e5d1a5cbb34607 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 21 May 2025 17:41:16 -0700 Subject: [PATCH 07/25] restore file --- e2e_fp8_sparse.csv | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 e2e_fp8_sparse.csv diff --git a/e2e_fp8_sparse.csv b/e2e_fp8_sparse.csv new file mode 100644 index 0000000000..05a80e13b7 --- /dev/null +++ b/e2e_fp8_sparse.csv @@ -0,0 +1,8 @@ +num_tokens,bf16_latency (us),bf16_c_latency (us),fp8_c_time (us),fp8_c_sparse_time (us),fp8_c_activation_sparse_time (us),speedup +64,166.81599617004395,163.03999722003937,103.00800204277039,74.30399954319,102.81600058078766,1.0018674278409796 +128,156.25600516796112,151.5199989080429,99.93600100278854,75.45600086450577,102.04800218343735,0.9793038458817415 +256,172.28800058364868,159.58400070667267,114.07999694347382,82.43200182914734,111.07199639081955,1.0270815385551393 +512,218.87999773025513,204.6079933643341,144.0960019826889,114.56000059843063,139.48799669742584,1.0330351384661336 +1024,394.4000005722046,392.5440013408661,251.10399723052979,196.4160054922104,227.90400683879852,1.1017972027501084 +2048,764.6080255508423,734.8160147666931,480.70400953292847,381.1520040035248,426.68798565864563,1.1265937305239622 +4096,1658.8159799575806,1623.5840320587158,901.3440012931824,779.0079712867737,843.392014503479,1.0687129896811043 From 3248e79d11211d55ed1173b36737742adb2a86df Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 21 May 2025 17:47:15 -0700 Subject: [PATCH 08/25] cleanup --- torchao/prototype/sparsity/activation/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/prototype/sparsity/activation/utils.py b/torchao/prototype/sparsity/activation/utils.py index 6fef483ea0..696649b18c 100644 --- a/torchao/prototype/sparsity/activation/utils.py +++ b/torchao/prototype/sparsity/activation/utils.py @@ -69,9 +69,7 @@ def __init__(self): super().__init__() def forward(self, x): - res = F.relu(x) ** 2 - # print((res==0).sum() / res.numel()) - return res + return F.relu(x) ** 2 def profiler_runner(path, fn, *args, **kwargs): From ddc1d9a32aacab1b4010657597c8d2842f205475 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 06:07:29 -0700 Subject: [PATCH 09/25] ruff format sparse_api --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 84 +++++++++++-------- torchao/kernel/splitk_sparse_gemv.py | 5 -- torchao/sparsity/sparse_api.py | 23 +++-- 3 files changed, 59 insertions(+), 53 deletions(-) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index a37c370d80..fbab8c0671 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -3,8 +3,6 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-import copy - import pandas as pd import torch from torch import nn @@ -12,7 +10,7 @@ from triton.testing import do_bench from torchao.prototype.sparsity.activation.srelu_linear import ( - ActivationSparseLinearConfig, + SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, ) from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( @@ -22,23 +20,16 @@ PerRow, quantize_, ) -from torchao.sparsity.utils import create_binary_tensor def benchmark_microseconds(f, *args): return do_bench(lambda: f(*args), return_mode="median") * 1e3 -def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): - - target_sparsity_output = create_binary_tensor((1, num_tokens, intermediate_size), 0.9).cuda().to(torch.bfloat16) - target_sparsity_output = torch.randn(num_tokens, intermediate_size).cuda().to(torch.bfloat16) * target_sparsity_output - print(target_sparsity_output) - - +def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): ffn_ref = ( nn.Sequential( - # nn.Linear(hidden_size, intermediate_size, bias=False), + nn.Linear(hidden_size, intermediate_size, bias=False), SquaredReLU(), nn.Linear(intermediate_size, hidden_size, bias=False), ) @@ -46,22 +37,32 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): .cuda() ) - - # input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda() - input_tensor = target_sparsity_output - # ffn_ref[0].weight.data = torch.linalg.solve(input_tensor, target_sparsity_output).T - # ffn_ref[0].weight.data = torch.load("/data/users/jessecai/ao/checkpoints/meta-llama/ffn_up.pt") - + input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda() fp16_time = benchmark_microseconds(ffn_ref, input_tensor) - # breakpoint() # bf16 - ffn_clone = copy.deepcopy(ffn_ref) + ffn_clone = ( + nn.Sequential( + nn.Linear(hidden_size, intermediate_size, bias=False), + SquaredReLU(), + nn.Linear(intermediate_size, hidden_size, bias=False), + ) + .to(torch.bfloat16) + .cuda() + ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp16_c_time = benchmark_microseconds(ffn_clone, input_tensor) # fp8 - ffn_clone = copy.deepcopy(ffn_ref) + ffn_clone = ( + nn.Sequential( + nn.Linear(hidden_size, intermediate_size, bias=False), + SquaredReLU(), + nn.Linear(intermediate_size, hidden_size, bias=False), + ) + .to(torch.bfloat16) + .cuda() + ) quantize_( ffn_clone, Float8DynamicActivationFloat8WeightConfig( @@ -72,27 +73,38 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor) # fp8 sparse - ffn_clone = copy.deepcopy(ffn_ref) + ffn_clone = ( + nn.Sequential( + nn.Linear(hidden_size, intermediate_size, bias=False), + SquaredReLU(), + nn.Linear(intermediate_size, hidden_size, bias=False), + ) + .to(torch.bfloat16) + .cuda() + ) quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig()) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) # activation fp8 sparse - ffn_clone = copy.deepcopy(ffn_ref) - # quantize_( - # ffn_clone[0], - # Float8DynamicActivationFloat8WeightConfig( - # granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) - # ), - # ) - # quantize_( - # ffn_clone, - # SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), - # filter_fn=lambda mod, fqn: "1" in fqn, - # ) + ffn_clone = ( + nn.Sequential( + nn.Linear(hidden_size, intermediate_size, bias=False), + # 
no Squared RELU since it will be fused into the second linear + nn.Linear(intermediate_size, hidden_size, bias=False), + ) + .to(torch.bfloat16) + .cuda() + ) + quantize_( + ffn_clone[0], + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + ), + ) quantize_( ffn_clone, - ActivationSparseLinearConfig(), + SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), filter_fn=lambda mod, fqn: "1" in fqn, ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) @@ -112,7 +124,7 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): if __name__ == "__main__": with torch.no_grad(): results = [] - for num_tokens in tqdm([1]): + for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): results.append(benchmark(num_tokens)) torch.compiler.reset() diff --git a/torchao/kernel/splitk_sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py index 65f6c645e5..7e9c4d94fa 100644 --- a/torchao/kernel/splitk_sparse_gemv.py +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -87,7 +87,6 @@ def splitk_sparse_gemv_kernel( def splitk_sparse_gemv( x: torch.Tensor, weight: torch.Tensor, - out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: """ Compute y = sparse(X) @ weight. @@ -129,10 +128,6 @@ def splitk_sparse_gemv( ) if x.dtype is not output.dtype: - # warnings.warn(f"Warning: incuring dtype conversion overhead since input dtype is not torch.float16. Detected dtype: {x.dtype}. ") return output.to(dtype=x.dtype) - # if out_dtype: - # return output.to(dtype=out_dtype) - return output diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 1151ed437a..2fa7fb2fc3 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -156,6 +156,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: from torchao.utils import TorchAOBaseTensor + class ActivationSparseTensor(TorchAOBaseTensor): data: Optional[torch.Tensor] @@ -207,9 +208,7 @@ def __tensor_unflatten__( @classmethod def from_dense(cls, weight): - return cls(weight.shape, - weight.data.t().contiguous().t(), - requires_grad=False) + return cls(weight.shape, weight.data.t().contiguous().t(), requires_grad=False) def apply_fn_to_shard(self, func): return ActivationSparseTensor( @@ -218,6 +217,7 @@ def apply_fn_to_shard(self, func): requires_grad=self.requires_grad, ) + # Subclass op dispatch registration implements = ActivationSparseTensor.implements aten = torch.ops.aten @@ -237,18 +237,19 @@ def _(func, types, args, kwargs): requires_grad=False, ) -@implements( - [aten.copy_.default] -) + +@implements([aten.copy_.default]) def _(func, types, args, kwargs): self = args[0] src = args[1] self.data.copy_(src.data) return + @implements(torch.nn.functional.linear) def sparse_activation_linear(func, types, args, kwargs): x_orig, w, bias = args + print(x_orig.shape) assert bias is None x = x_orig.view(-1, x_orig.size(-1)) # M = w.shape[0] @@ -256,8 +257,7 @@ def sparse_activation_linear(func, types, args, kwargs): if x.shape[0] == 1: x_relu = torch.square(torch.nn.functional.relu(x)) - res = torch.ops.torchao.splitk_sparse_gemv(x_relu, - w.data) + res = torch.ops.torchao.splitk_sparse_gemv(x_relu, w.data) return res.view(*x_orig.shape[:-1], w.shape[0]) else: x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) @@ -267,7 +267,7 @@ def sparse_activation_linear(func, types, args, kwargs): @dataclass class ActivationSparseLinearConfig(AOBaseConfig): """ - Adds in acceleration for activation sparsity to linear layers 
for decode. + Adds in acceleration for activation sparsity to linear layers for decode. Args: `activation_dtype`: data type for quantized activation tensor. @@ -277,9 +277,8 @@ class ActivationSparseLinearConfig(AOBaseConfig): activation_dtype: torch.dtype = torch.float8_e4m3fn weight_dtype: torch.dtype = torch.float8_e4m3fn -@register_quantize_module_handler( - ActivationSparseLinearConfig -) + +@register_quantize_module_handler(ActivationSparseLinearConfig) def _( module: torch.nn.Module, config: ActivationSparseLinearConfig, From 55980daf73bad85b6da14f149c2b99843aa42efe Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 06:08:37 -0700 Subject: [PATCH 10/25] ruff --- test/sparsity/test_activation24.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 9c7673d344..57d8a6aedb 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -143,7 +143,6 @@ def srelu_linear(x): torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) - def test_splitk_sparse_gemv(): torch.manual_seed(0) @@ -156,23 +155,5 @@ def test_splitk_sparse_gemv(): sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) dense_res = F.linear(activation, weight_transposed) - # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. - torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) - - -def test_splitk_sparse_gemv_fp8(): - - torch.nn.Linear() - torch.manual_seed(0) - - activation = create_binary_tensor((1, 1, 4096), 0.2).cuda().to(torch.float16) - weight = torch.randn(16384, 4096, dtype=torch.float16).cuda() - - # weight must be column major - weight_transposed = weight.T.contiguous().T - - sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) - dense_res = F.linear(activation, weight_transposed) - - # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. + # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. 
torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1) From dd11ddb37932082f6c4b674dde9a540aaa86fd28 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 06:10:18 -0700 Subject: [PATCH 11/25] ruff one more time --- torchao/sparsity/sparse_api.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 2fa7fb2fc3..9534a86dd0 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -24,24 +24,6 @@ register_quantize_module_handler, ) from torchao.sparsity.blocksparse import BlockSparseTensor -from dataclasses import dataclass - -import torch -from torch import nn - -from torchao.core.config import AOBaseConfig -from torchao.ops import ( - rowwise_scaled_linear_sparse_cutlass_f8f8, -) -from torchao.quantization.quant_api import ( - _float8_cutlass_quant, -) -from torchao.quantization.transform_module import ( - register_quantize_module_handler, -) - -from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv -from torch.utils._python_dispatch import return_and_correct_aliasing # Sparsity helper functions From a545b3e4b1041934e696264e6018475ae7728609 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 06:12:27 -0700 Subject: [PATCH 12/25] ruff --- benchmarks/benchmark_splitk_sparse_gemv.py | 16 +++++---- torchao/kernel/splitk_sparse_gemv.py | 42 ++++++++++++---------- torchao/sparsity/utils.py | 4 +-- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py index 61ccb022d1..de895ca962 100644 --- a/benchmarks/benchmark_splitk_sparse_gemv.py +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -1,26 +1,30 @@ import torch +import torch.nn.functional as F from triton.testing import do_bench from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.sparsity.utils import create_binary_tensor -import torch.nn.functional as F - dtype = torch.float8_e4m3fn for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: - a = create_binary_tensor((1, 4096), sparsity_level).cuda().to(dtype) b = torch.randn(16384, 4096).cuda().to(dtype).T.contiguous().T - sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b, out_dtype=torch.bfloat16)) * 1e6 + sparse_time = ( + do_bench(lambda: splitk_sparse_gemv(a, b, out_dtype=torch.bfloat16)) * 1e6 + ) - dense_time = do_bench(lambda: F.linear(a.to(torch.float16), b.to(torch.float16))) * 1e6 + dense_time = ( + do_bench(lambda: F.linear(a.to(torch.float16), b.to(torch.float16))) * 1e6 + ) # b = torch.randn(4096, 16384).cuda().to(dtype).T.contiguous().T # dense_time = do_bench(lambda: torch._scaled_mm(a.squeeze(0), b, # scale_a=torch.Tensor([1]).cuda(), # scale_b=torch.Tensor([1]).cuda(), # out_dtype=torch.bfloat16)) * 1e6 speedup = dense_time / sparse_time - print(f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}") + print( + f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}" + ) diff --git a/torchao/kernel/splitk_sparse_gemv.py b/torchao/kernel/splitk_sparse_gemv.py index 7e9c4d94fa..a556bc4195 100644 --- a/torchao/kernel/splitk_sparse_gemv.py +++ b/torchao/kernel/splitk_sparse_gemv.py @@ -3,9 +3,9 @@ Since we already have sparse activations from ReLU, we can get rid of the thresholding step and just use the sparse tensor directly. 
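
Illustrative usage (editor's sketch, mirroring benchmarks/benchmark_splitk_sparse_gemv.py;
shapes and dtypes are examples only, and the weight must be laid out column-major,
which is what the kernel below assumes):

    x = torch.nn.functional.relu(torch.randn(1, 4096, dtype=torch.float16, device="cuda"))
    w = torch.randn(16384, 4096, dtype=torch.float16, device="cuda").T.contiguous().T
    y = splitk_sparse_gemv(x, w)  # only loads weight values that multiply nonzero elements of x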
""" + import sys import warnings -from typing import Optional import torch import triton @@ -16,8 +16,8 @@ # to suppress repeated warnings when being used in a training loop. warnings.simplefilter("once") -configs=[ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2), +configs = [ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4), triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2), triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4), @@ -33,7 +33,6 @@ triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4), triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4), triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4), - # # Llama 3 variants can use BLOCK_N >= 1024 triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4), triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4), @@ -41,6 +40,7 @@ triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4), ] + @triton.autotune( configs=configs, key=["CACHE_KEY_M", "CACHE_KEY_N"], @@ -48,31 +48,39 @@ ) @triton.jit def splitk_sparse_gemv_kernel( - Y, # Pointers to matrices - A, X, + Y, # Pointers to matrices + A, + X, # Matrix dimensions - N, M, - CACHE_KEY_N, CACHE_KEY_M, + N, + M, + CACHE_KEY_N, + CACHE_KEY_M, # Meta-parameters - BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, ): start_n = tl.program_id(0) start_m = tl.program_id(1) # now compute the block that each program will go through # rn (resp. rm) denotes a range of indices for rows (resp. col) of A - + rn = start_n * BLOCK_N + tl.arange(0, BLOCK_N) rm = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - + A_ptr = A + (rm[:, None] * N + rn[None, :]) X_ptr = X + rm Y_ptr = Y + rn - + # eviction policy go brrr - x0 = tl.load(X_ptr, mask=rm < M, other=0.0, eviction_policy='evict_last') # reuse x across threadblocks - idx = (x0 != 0.0) + x0 = tl.load( + X_ptr, mask=rm < M, other=0.0, eviction_policy="evict_last" + ) # reuse x across threadblocks + idx = x0 != 0.0 # selectively load weight rows - a = tl.load(A_ptr, mask=idx[:, None], other=0.0, eviction_policy='evict_first') # only load weights once per threadblock + a = tl.load( + A_ptr, mask=idx[:, None], other=0.0, eviction_policy="evict_first" + ) # only load weights once per threadblock acc0 = tl.sum(a.to(tl.float32) * x0.to(tl.float32)[:, None], axis=0) # rematerialize rm and rn to save registers @@ -81,7 +89,6 @@ def splitk_sparse_gemv_kernel( tl.atomic_add(Y_ptr, acc0, mask=rn < N) - # NOTE: assumes that weight is column major @triton_op("torchao::splitk_sparse_gemv", mutates_args={}) def splitk_sparse_gemv( @@ -98,7 +105,7 @@ def splitk_sparse_gemv( seq_len, _ = x.shape assert x.shape[-1] == Z assert x.is_contiguous() - + assert weight.stride(1) > 1, "weight should be column major" # 1D launch kernel where each block gets its own program. @@ -114,7 +121,6 @@ def splitk_sparse_gemv( dtype=torch.float16, ) - kernel = wrap_triton(splitk_sparse_gemv_kernel) kernel[grid]( output, # data ptrs diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 06656e032f..4b6a19b183 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -50,11 +50,11 @@ def create_semi_structured_tensor(r, c, dtype): def create_binary_tensor(shape, percent_zeros): """ Creates a PyTorch tensor with a specific percentage of zeros and ones. 
- + Args: shape (tuple): The shape of the tensor to create percent_zeros (float): Percentage of zeros in the tensor (between 0 and 1) - + Returns: torch.Tensor: A tensor with specified percentage of zeros and ones """ From e099d786437d13e1eadc52e80fd9d1432f9ddb2d Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 06:29:54 -0700 Subject: [PATCH 13/25] run tests only on cuda --- test/sparsity/test_activation24.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 57d8a6aedb..2e4957ff90 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -143,6 +143,7 @@ def srelu_linear(x): torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) +@unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run") def test_splitk_sparse_gemv(): torch.manual_seed(0) From 747e0e1ffd690249692b6060a6b35ce443dfb75c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 08:45:41 -0700 Subject: [PATCH 14/25] check in changes before merge main --- torchao/sparsity/sparse_api.py | 101 ++++++++++++++++++++++++++++----- 1 file changed, 86 insertions(+), 15 deletions(-) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 9534a86dd0..aee4894b80 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -24,6 +24,24 @@ register_quantize_module_handler, ) from torchao.sparsity.blocksparse import BlockSparseTensor +from dataclasses import dataclass + +import torch +from torch import nn + +from torchao.core.config import AOBaseConfig +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, +) +from torchao.quantization.quant_api import ( + _float8_cutlass_quant, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) + +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv +from torch.utils._python_dispatch import return_and_correct_aliasing # Sparsity helper functions @@ -135,20 +153,26 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: extra_args=(config,), ) +def _to_fp8_rowwise(x: torch.Tensor, dtype): + max_v = torch.finfo(dtype).max + x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() + x = (x / x_scale).to(dtype) + return x, x_scale from torchao.utils import TorchAOBaseTensor - class ActivationSparseTensor(TorchAOBaseTensor): data: Optional[torch.Tensor] + scale: Optional[torch.Tensor] - __slots__ = ["data"] + __slots__ = ["data", "scale"] @staticmethod def __new__( # noqa: PYI034 cls, shape: torch.Size, data: Optional[torch.Tensor], + scale: Optional[torch.Tensor], requires_grad: bool = False, ): assert data is not None @@ -160,6 +184,7 @@ def __new__( # noqa: PYI034 } tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] tensor.data = data + tensor.scale = scale return tensor def __repr__(self) -> str: # type: ignore[override] @@ -185,21 +210,32 @@ def __tensor_unflatten__( return cls( shape=shape, data=inner_tensors.get("data", None), + scale=inner_tensors.get("scale", None), requires_grad=requires_grad, ) @classmethod - def from_dense(cls, weight): - return cls(weight.shape, weight.data.t().contiguous().t(), requires_grad=False) + def from_dense(cls, weight, use_fp8=True): + if use_fp8: + weight, scale = _to_fp8_rowwise(weight, torch.float8_e4m3fn) + return cls(weight.shape, + data=weight.data.t().contiguous().t(), + scale=scale, + requires_grad=False) + else: + return 
cls(weight.shape, + data=weight.data.t().contiguous().t(), + scale=None, + requires_grad=False) def apply_fn_to_shard(self, func): return ActivationSparseTensor( shape=self.shape, data=func(self.data), + scale=func(self.scale), requires_grad=self.requires_grad, ) - # Subclass op dispatch registration implements = ActivationSparseTensor.implements aten = torch.ops.aten @@ -213,25 +249,30 @@ def apply_fn_to_shard(self, func): ) def _(func, types, args, kwargs): new_data = func(args[0].data, *args[1:], **kwargs) + if args[0].scale is None: + new_scale = None + else: + new_scale = func(args[0].scale, *args[1:], **kwargs) return ActivationSparseTensor( new_data.shape, data=new_data, + scale=new_scale, requires_grad=False, ) - -@implements([aten.copy_.default]) +@implements( + [aten.copy_.default] +) def _(func, types, args, kwargs): self = args[0] src = args[1] self.data.copy_(src.data) + self.scale.copy_(src.scale) return - @implements(torch.nn.functional.linear) def sparse_activation_linear(func, types, args, kwargs): x_orig, w, bias = args - print(x_orig.shape) assert bias is None x = x_orig.view(-1, x_orig.size(-1)) # M = w.shape[0] @@ -239,17 +280,47 @@ def sparse_activation_linear(func, types, args, kwargs): if x.shape[0] == 1: x_relu = torch.square(torch.nn.functional.relu(x)) - res = torch.ops.torchao.splitk_sparse_gemv(x_relu, w.data) + res = torch.ops.torchao.splitk_sparse_gemv(x_relu, + w.data) return res.view(*x_orig.shape[:-1], w.shape[0]) else: - x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) - return torch.nn.functional.linear(x_orig_relu, w.data, bias) + print(x.shape) + X_scale = torch.empty([x.shape[0], 1], dtype=torch.float32, device=x.device) + Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( + x, + "cutlass", + "srelu", + "largest", + dtype=torch.float8_e4m3fn, + scale=X_scale, + ) + X_scale_squeeze = X_scale.squeeze() + + breakpoint() + + result = rowwise_scaled_linear_sparse_cutlass_f8f8( + w.data, + w.scale.squeeze(), + Xq_sparse, + X_meta, + X_scale_squeeze, + bias=None, + out_dtype=torch.bfloat16, + ).t() + + breakpoint() + + return result + + # For normal linear + # x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) + # return torch.nn.functional.linear(x_orig_relu, w.data, bias) @dataclass class ActivationSparseLinearConfig(AOBaseConfig): """ - Adds in acceleration for activation sparsity to linear layers for decode. + Adds in acceleration for activation sparsity to linear layers for decode. Args: `activation_dtype`: data type for quantized activation tensor. 
@@ -259,8 +330,8 @@ class ActivationSparseLinearConfig(AOBaseConfig): activation_dtype: torch.dtype = torch.float8_e4m3fn weight_dtype: torch.dtype = torch.float8_e4m3fn - -@register_quantize_module_handler(ActivationSparseLinearConfig) +@register_quantize_module_handler( + ActivationSparseLinearConfig) def _( module: torch.nn.Module, config: ActivationSparseLinearConfig, From 2a68435e726f40b0992595d8006a8b0b53c89204 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 27 May 2025 10:38:53 -0700 Subject: [PATCH 15/25] wip --- test/sparsity/test_activation24.py | 78 +++++- torchao/csrc/cuda/activation24/sparse_gemm.cu | 10 +- torchao/csrc/cuda/activation24/sparsify24.cu | 10 +- torchao/ops.py | 20 +- .../sparsity/activation/srelu_linear.py | 41 ++- torchao/sparsity/activation/__init__.py | 0 .../activation/squared_relu_sparse.py | 253 ++++++++++++++++++ torchao/sparsity/sparse_api.py | 189 +------------ 8 files changed, 376 insertions(+), 225 deletions(-) create mode 100644 torchao/sparsity/activation/__init__.py create mode 100644 torchao/sparsity/activation/squared_relu_sparse.py diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index e4bfd44d4f..2c96e6639f 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -121,11 +121,11 @@ def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): ) # define reference implementation - def srelu_linear(x): + def reference_srelu(x): x = F.relu(x) ** 2 return reference_linear(x) - reference_srelu = torch.compile(srelu_linear, fullgraph=True) + # reference_srelu = torch.compile(reference_srelu, fullgraph=True) # this only works with fullgraph=True, errors in eager # TODO figure out exactly why this happens @@ -134,9 +134,9 @@ def srelu_linear(x): SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), ) # (reference_linear_copy) - reference_linear_copy.forward = torch.compile( - reference_linear_copy.forward, fullgraph=True - ) + # reference_linear_copy.forward = torch.compile( + # reference_linear_copy.forward, fullgraph=True + # ) reference_output = reference_srelu(input_tensor) custom_output = reference_linear_copy(input_tensor) @@ -144,6 +144,45 @@ def srelu_linear(x): torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) +from torchao.sparsity.sparse_api import ActivationSparseLinearConfig +@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +def test_asdf(M=512, K=2048, N=1024): + with torch.no_grad(): + torch.manual_seed(0) + input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + # we have to wrap in a sequential block for quantize_ to work properly + reference_linear = torch.nn.Sequential( + torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16) + ) + reference_linear_copy = copy.deepcopy(reference_linear) + + quantize_( + reference_linear, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=False) + ), + ) + + # this only works with fullgraph=True, errors in eager + # TODO figure out exactly why this happens + sparsify_( + reference_linear_copy, + ActivationSparseLinearConfig(), + ) + # (reference_linear_copy) + # reference_linear_copy.forward = torch.compile( + # reference_linear_copy.forward, fullgraph=True + # ) + + reference_output = reference_linear(input_tensor) + custom_output = reference_linear_copy(input_tensor) + + print(reference_output) + print(custom_output) + + 
torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) + + @unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run") def test_splitk_sparse_gemv(): torch.manual_seed(0) @@ -224,3 +263,32 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype): A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype ) assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01) + +@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor_compile( + M=512, N=1024, K=256, dtype=torch.float8_e4m3fn +) -> None: + def _to_fp8_rowwise(x: torch.Tensor, dtype): + max_v = torch.finfo(dtype).max + x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() + x = (x / x_scale).to(dtype) + return x, x_scale + + torch.manual_seed(0) + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + A, a_scale = _to_fp8_rowwise(A_dense, dtype) + + B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16) + B, b_scale = _to_fp8_rowwise(B_dense, dtype) + + B = B.T + b_scale = b_scale.T + + A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A) + out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale + ) + out_ref = torch._scaled_mm( + A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype + ) + assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01) diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index 776766794e..9e978edc16 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -343,9 +343,9 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); } -TORCH_LIBRARY_IMPL(torchao, Meta, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), - TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); -} +// TORCH_LIBRARY_IMPL(torchao, Meta, m) { +// m.impl( +// TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), +// TORCH_FN(torchao::_sparse24_fp8_sm90_cutlass_gemm)); +// } #endif diff --git a/torchao/csrc/cuda/activation24/sparsify24.cu b/torchao/csrc/cuda/activation24/sparsify24.cu index e8949fa5d8..b41bf576ca 100644 --- a/torchao/csrc/cuda/activation24/sparsify24.cu +++ b/torchao/csrc/cuda/activation24/sparsify24.cu @@ -412,8 +412,8 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { TORCH_FN(sparse24_sm90_sparsify)); } -TORCH_LIBRARY_IMPL(torchao, Meta, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"), - TORCH_FN(sparse24_sm90_sparsify)); -} +// TORCH_LIBRARY_IMPL(torchao, Meta, m) { +// m.impl( +// TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"), +// TORCH_FN(sparse24_sm90_sparsify)); +// } diff --git a/torchao/ops.py b/torchao/ops.py index b91bb8ae18..ab262a2eff 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -843,26 +843,18 @@ def sparse24_sm90_sparsify( ) -def sparse24_fp8_sm90_cutlass_gemm( +@register_custom_op("torchao::sparse24_fp8_sm90_cutlass_gemm") +def _( a: Tensor, meta: Tensor, b: Tensor, a_scale: Optional[Tensor], b_scale: Optional[Tensor], - swizzle_size: int, - swizzle_axis: str, - sm_count: int, + swizzle_size: int = 8, + swizzle_axis: str = 'n', + sm_count: int = 128, ) -> Tensor: - return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( - a, - meta, - b, - a_scale=a_scale, - b_scale=b_scale, - swizzle_size=swizzle_size, - swizzle_axis=swizzle_axis, - sm_count=sm_count, - ) + 
return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device) def swizzle_mm( diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index f8c3288b67..e126a10e5f 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -38,6 +38,12 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform( ): return FP8SemiSparseActivationLinear.from_dense(module, config) +def _to_fp8_rowwise(x: torch.Tensor, dtype): + max_v = torch.finfo(dtype).max + x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() + x = (x / x_scale).to(dtype) + return x, x_scale + class FP8SemiSparseActivationLinear(nn.Module): """ @@ -48,12 +54,16 @@ def __init__(self, weight, config) -> None: super().__init__() self.config = config - W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype) - self.Wq = W_aqt.tensor_impl.float8_data - self.W_scale = W_aqt.tensor_impl.scale + # W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype) + # self.Wq = W_aqt.tensor_impl.float8_data + # self.W_scale = W_aqt.tensor_impl.scale + W, W_scale = _to_fp8_rowwise(weight, self.config.weight_dtype) + self.W = W + self.W_scale = W_scale def forward(self, x): X_scale = torch.empty([x.shape[0], 1], device=x.device, dtype=torch.float32) + # X_scale = _float8_cutlass_quant(x, self.config.activation_dtype).tensor_impl.scale.repeat([x.shape[0], 1]) Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( x, "cutlass", @@ -62,16 +72,25 @@ def forward(self, x): dtype=self.config.activation_dtype, scale=X_scale, ) - - result = rowwise_scaled_linear_sparse_cutlass_f8f8( - self.Wq, - self.W_scale, + breakpoint() + result = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( Xq_sparse, X_meta, - X_scale, - bias=None, - out_dtype=torch.bfloat16, - ).t() + self.W.T, + a_scale=X_scale, + b_scale=self.W_scale.T, + ) + + + # result = rowwise_scaled_linear_sparse_cutlass_f8f8( + # self.Wq, + # self.W_scale, + # Xq_sparse, + # X_meta, + # X_scale, + # bias=None, + # out_dtype=torch.bfloat16, + # ).t() return result diff --git a/torchao/sparsity/activation/__init__.py b/torchao/sparsity/activation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/sparsity/activation/squared_relu_sparse.py b/torchao/sparsity/activation/squared_relu_sparse.py new file mode 100644 index 0000000000..7d40aeb97d --- /dev/null +++ b/torchao/sparsity/activation/squared_relu_sparse.py @@ -0,0 +1,253 @@ + +import types +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from torch.sparse import to_sparse_semi_structured + +from torchao.core.config import AOBaseConfig +from torchao.float8.inference import Float8MMConfig +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) +from torchao.quantization.quant_api import ( + _is_linear, + _linear_extra_repr, + _replace_with_custom_fn_if_matches_filter, +) +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, + register_quantize_module_handler, +) +from torchao.sparsity.blocksparse import BlockSparseTensor +from dataclasses import dataclass + +import torch +from torch import nn + +from torchao.core.config import AOBaseConfig +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, +) +from torchao.quantization.quant_api import ( + _float8_cutlass_quant, +) +from torchao.quantization.transform_module 
import ( + register_quantize_module_handler, +) + +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv +from torch.utils._python_dispatch import return_and_correct_aliasing +def _to_fp8_rowwise(x: torch.Tensor, dtype): + max_v = torch.finfo(dtype).max + x_scale = (x.abs().max(1, keepdim=True)[0].clip(1e-12) / max_v).float() + x = (x.float() / x_scale).clamp(min=-max_v, max=max_v).to(dtype) + return x, x_scale + + +from torchao.utils import TorchAOBaseTensor +from torchao.quantization import LinearActivationQuantizedTensor + +from torchao.float8.float8_utils import tensor_to_scale +from torchao.float8.float8_scaling_utils import ( + get_maybe_axiswise_dim, + hp_tensor_to_float8_dynamic, + hp_tensor_and_scale_to_float8, +) +from torchao.float8.config import CastConfig, ScalingGranularity + +@dataclass +class ActivationSparseLinearConfig(AOBaseConfig): + """ + Adds in acceleration for activation sparsity to linear layers for decode. + + Args: + `activation_dtype`: data type for quantized activation tensor. + `weight_dtype`: data type for quantized weight tensor. + """ + + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + +@register_quantize_module_handler( + ActivationSparseLinearConfig) +def _( + module: torch.nn.Module, + config: ActivationSparseLinearConfig, +): + new_weight = ActivationSparseTensor.from_dense(module.weight.data) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +class ActivationSparseTensor(TorchAOBaseTensor): + data: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + + __slots__ = ["data", "scale"] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + data: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + requires_grad: bool = False, + ): + assert data is not None + kwargs = { + "device": data.device, + "dtype": data.dtype, + "layout": data.layout, + "requires_grad": requires_grad, + } + tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + tensor.data = data + tensor.scale = scale + return tensor + + def __repr__(self) -> str: # type: ignore[override] + assert hasattr(self, "shape") + return f"{self.__class__.__name__}(shape={self.shape})" + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta, + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, requires_grad = tensor_meta + return cls( + shape=shape, + data=inner_tensors.get("data", None), + scale=inner_tensors.get("scale", None), + requires_grad=requires_grad, + ) + + @classmethod + def from_dense(cls, weight, use_fp8=True): + if use_fp8: + # weight, scale = _to_fp8_rowwise(weight, torch.float8_e4m3fn) + # scale = None + scale = tensor_to_scale( + weight, + torch.float8_e4m3fn, + reduce_amax=False, + device_mesh=None, + scaling_granularity=ScalingGranularity.TENSORWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=False, + ) + x2_lp = hp_tensor_and_scale_to_float8(weight, scale, torch.float8_e4m3fn) + return cls(weight.shape, + data=x2_lp, + scale=None, + requires_grad=False) + else: + return cls(weight.shape, + data=weight.data.t().contiguous().t(), + scale=None, + requires_grad=False) + + def apply_fn_to_shard(self, 
func): + return ActivationSparseTensor( + shape=self.shape, + data=func(self.data), + scale=func(self.scale), + requires_grad=self.requires_grad, + ) + +# Subclass op dispatch registration +implements = ActivationSparseTensor.implements +aten = torch.ops.aten + + +@implements( + [ + aten.detach.default, + aten.slice.Tensor, + ] +) +def _(func, types, args, kwargs): + new_data = func(args[0].data, *args[1:], **kwargs) + if args[0].scale is None: + new_scale = None + else: + new_scale = func(args[0].scale, *args[1:], **kwargs) + return ActivationSparseTensor( + new_data.shape, + data=new_data, + scale=new_scale, + requires_grad=False, + ) + +@implements( + [aten.copy_.default] +) +def _(func, types, args, kwargs): + self = args[0] + src = args[1] + if not isinstance(src, ActivationSparseTensor): + src_subclass = ActivationSparseTensor.from_dense(src) + + self.data.copy_(src.data) + # slef.scale.copy_(src.scale) + if self.scale is None: + self.scale = None + else: + self.scale.copy_(src.scale) + return + +@implements(torch.nn.functional.linear) +def sparse_activation_linear(func, types, args, kwargs): + x_orig, w, bias = args + assert bias is None + x = x_orig.view(-1, x_orig.size(-1)) + # M = w.shape[0] + # K = w.shape[1] + + if x.shape[0] % 64 != 0: + # w_dequantized = (w.data.to(torch.bfloat16)) + # x_relu = torch.square(torch.nn.functional.relu(x)) + return torch.nn.functional.linear(x_orig, w.data.to_original_precision().to(torch.bfloat16), bias) + # res = torch.ops.torchao.splitk_sparse_gemv(x_relu, + # w.data) + # return res.view(*x_orig.shape[:-1], w.shape[0]) + else: + # X_scale = torch.empty([x.shape[0], 1], dtype=torch.float32, device=x.device) + # Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( + # x, + # "cutlass", + # "identity", + # "largest", + # dtype=torch.float8_e4m3fn, + # scale=X_scale, + # ) + x_ast = ActivationSparseTensor.from_dense(x_orig) # .data.to_original_precision().to(torch.bfloat16) + # x_orig = + # return torch.nn.functional.linear(x_ast, w.data.to_original_precision().to(torch.bfloat16), bias) + breakpoint() + return torch._scaled_mm(x_ast.data._data, w.data._data.T, scale_a=x_ast.data._scale, scale_b=w.data._scale.T, out_dtype=torch.bfloat16) + + + out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + Xq_sparse, X_meta, w.data._data.T, a_scale=X_scale, b_scale=w.data._scale.T, + ) + out_sparse = out_sparse.view(*x_orig.shape[:-1], w.shape[0]) + + return out_sparse + + # For normal linear + # x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) + # return torch.nn.functional.linear(x_orig_relu, w.data, bias) diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index aee4894b80..c84333a597 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -11,6 +11,7 @@ from torch.sparse import to_sparse_semi_structured from torchao.core.config import AOBaseConfig +from torchao.float8.inference import Float8MMConfig from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( WeightNormSparsifier, ) @@ -153,190 +154,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: extra_args=(config,), ) -def _to_fp8_rowwise(x: torch.Tensor, dtype): - max_v = torch.finfo(dtype).max - x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() - x = (x / x_scale).to(dtype) - return x, x_scale - -from torchao.utils import TorchAOBaseTensor - -class ActivationSparseTensor(TorchAOBaseTensor): - data: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - - __slots__ = ["data", 
"scale"] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - data: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - requires_grad: bool = False, - ): - assert data is not None - kwargs = { - "device": data.device, - "dtype": data.dtype, - "layout": data.layout, - "requires_grad": requires_grad, - } - tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - tensor.data = data - tensor.scale = scale - return tensor - - def __repr__(self) -> str: # type: ignore[override] - assert hasattr(self, "shape") - return f"{self.__class__.__name__}(shape={self.shape})" - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self.requires_grad) - return inner_tensors, tensor_meta - - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta, - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, requires_grad = tensor_meta - return cls( - shape=shape, - data=inner_tensors.get("data", None), - scale=inner_tensors.get("scale", None), - requires_grad=requires_grad, - ) - - @classmethod - def from_dense(cls, weight, use_fp8=True): - if use_fp8: - weight, scale = _to_fp8_rowwise(weight, torch.float8_e4m3fn) - return cls(weight.shape, - data=weight.data.t().contiguous().t(), - scale=scale, - requires_grad=False) - else: - return cls(weight.shape, - data=weight.data.t().contiguous().t(), - scale=None, - requires_grad=False) - - def apply_fn_to_shard(self, func): - return ActivationSparseTensor( - shape=self.shape, - data=func(self.data), - scale=func(self.scale), - requires_grad=self.requires_grad, - ) - -# Subclass op dispatch registration -implements = ActivationSparseTensor.implements -aten = torch.ops.aten - - -@implements( - [ - aten.detach.default, - aten.slice.Tensor, - ] -) -def _(func, types, args, kwargs): - new_data = func(args[0].data, *args[1:], **kwargs) - if args[0].scale is None: - new_scale = None - else: - new_scale = func(args[0].scale, *args[1:], **kwargs) - return ActivationSparseTensor( - new_data.shape, - data=new_data, - scale=new_scale, - requires_grad=False, - ) -@implements( - [aten.copy_.default] -) -def _(func, types, args, kwargs): - self = args[0] - src = args[1] - self.data.copy_(src.data) - self.scale.copy_(src.scale) - return - -@implements(torch.nn.functional.linear) -def sparse_activation_linear(func, types, args, kwargs): - x_orig, w, bias = args - assert bias is None - x = x_orig.view(-1, x_orig.size(-1)) - # M = w.shape[0] - # K = w.shape[1] - - if x.shape[0] == 1: - x_relu = torch.square(torch.nn.functional.relu(x)) - res = torch.ops.torchao.splitk_sparse_gemv(x_relu, - w.data) - return res.view(*x_orig.shape[:-1], w.shape[0]) - else: - print(x.shape) - X_scale = torch.empty([x.shape[0], 1], dtype=torch.float32, device=x.device) - Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( - x, - "cutlass", - "srelu", - "largest", - dtype=torch.float8_e4m3fn, - scale=X_scale, - ) - X_scale_squeeze = X_scale.squeeze() - - breakpoint() - - result = rowwise_scaled_linear_sparse_cutlass_f8f8( - w.data, - w.scale.squeeze(), - Xq_sparse, - X_meta, - X_scale_squeeze, - bias=None, - out_dtype=torch.bfloat16, - ).t() - - breakpoint() - - return result - - # For normal linear - # x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) - # return torch.nn.functional.linear(x_orig_relu, w.data, bias) - -@dataclass -class ActivationSparseLinearConfig(AOBaseConfig): - """ - Adds 
in acceleration for activation sparsity to linear layers for decode. - - Args: - `activation_dtype`: data type for quantized activation tensor. - `weight_dtype`: data type for quantized weight tensor. - """ - - activation_dtype: torch.dtype = torch.float8_e4m3fn - weight_dtype: torch.dtype = torch.float8_e4m3fn - -@register_quantize_module_handler( - ActivationSparseLinearConfig) -def _( - module: torch.nn.Module, - config: ActivationSparseLinearConfig, -): - new_weight = ActivationSparseTensor.from_dense(module.weight.data) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module +from torchao.sparsity.activation.squared_relu_sparse import ( + ActivationSparseLinearConfig +) From 3c16b7d50d9585489b244ce0e39ce04988106af0 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 14:54:44 -0700 Subject: [PATCH 16/25] update --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 55 ++- test/sparsity/test_activation24.py | 33 +- torchao/csrc/cuda/activation24/sparse_gemm.cu | 21 +- torchao/csrc/cuda/activation24/sparsify24.cu | 10 +- torchao/dtypes/affine_quantized_tensor.py | 15 + torchao/dtypes/affine_quantized_tensor_ops.py | 6 + .../floatx/cutlass_semi_sparse_layout.py | 135 +++++-- torchao/ops.py | 24 +- .../sparsity/activation/srelu_linear.py | 44 +-- torchao/quantization/quant_primitives.py | 2 + .../activation/squared_relu_sparse.py | 341 +++++++++++++----- torchao/sparsity/sparse_api.py | 3 +- 12 files changed, 474 insertions(+), 215 deletions(-) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index fbab8c0671..f74972f998 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -12,6 +12,9 @@ from torchao.prototype.sparsity.activation.srelu_linear import ( SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, ) +from torchao.sparsity.sparse_api import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig +) from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( Float8DynamicActivationFloat8SemiSparseWeightConfig, @@ -21,12 +24,14 @@ quantize_, ) +PROFILE = False + def benchmark_microseconds(f, *args): return do_bench(lambda: f(*args), return_mode="median") * 1e3 -def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): +def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): ffn_ref = ( nn.Sequential( nn.Linear(hidden_size, intermediate_size, bias=False), @@ -73,24 +78,31 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor) # fp8 sparse - ffn_clone = ( - nn.Sequential( - nn.Linear(hidden_size, intermediate_size, bias=False), - SquaredReLU(), - nn.Linear(intermediate_size, hidden_size, bias=False), - ) - .to(torch.bfloat16) - .cuda() - ) - quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig()) - ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) - fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) + # ffn_clone = ( + # nn.Sequential( + # nn.Linear(hidden_size, intermediate_size, bias=False), + # SquaredReLU(), + # nn.Linear(intermediate_size, hidden_size, bias=False), + # ) + # .to(torch.bfloat16) + # .cuda() + # ) + # quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + # ffn_clone.forward = torch.compile(ffn_clone.forward, 
fullgraph=True) + # fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) + + if PROFILE: + print("PROFILING FP8") + from torchao.prototype.sparsity.activation.utils import profiler_runner + inputs = (ffn_clone, input_tensor) + profiler_runner(None, benchmark_microseconds, *inputs) # activation fp8 sparse ffn_clone = ( nn.Sequential( nn.Linear(hidden_size, intermediate_size, bias=False), # no Squared RELU since it will be fused into the second linear + SquaredReLU(), nn.Linear(intermediate_size, hidden_size, bias=False), ) .to(torch.bfloat16) @@ -104,18 +116,26 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): ) quantize_( ffn_clone, - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), - filter_fn=lambda mod, fqn: "1" in fqn, + Float8DynamicSemiSparseActivationFloat8WeightConfig( + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + ), + filter_fn=lambda mod, fqn: "2" in fqn, ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) + if PROFILE: + print("PROFILING 24") + from torchao.prototype.sparsity.activation.utils import profiler_runner + inputs = (ffn_clone, input_tensor) + profiler_runner(None, benchmark_microseconds, *inputs) + return { "num_tokens": num_tokens, "bf16_latency (us)": fp16_time, "bf16_c_latency (us)": fp16_c_time, "fp8_c_time (us)": fp8_c_time, - "fp8_c_sparse_time (us)": fp8_c_sparse_time, + # "fp8_c_sparse_time (us)": fp8_c_sparse_time, "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time, "speedup": fp8_c_time / fp8_c_activation_sparse_time, } @@ -124,7 +144,8 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192): if __name__ == "__main__": with torch.no_grad(): results = [] - for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): + # for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): + for num_tokens in tqdm([512, 1024, 2048, 4096, 8192]): results.append(benchmark(num_tokens)) torch.compiler.reset() diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 2c96e6639f..05dd4edec5 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -9,6 +9,7 @@ quantize_, ) from torchao.quantization.quant_api import _float8_cutlass_quant +from torchao.sparsity.activation.squared_relu_sparse import Float8DynamicSemiSparseActivationFloat8WeightConfig torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True @@ -116,7 +117,7 @@ def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): quantize_( reference_linear, Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=False) + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), ) @@ -125,7 +126,7 @@ def reference_srelu(x): x = F.relu(x) ** 2 return reference_linear(x) - # reference_srelu = torch.compile(reference_srelu, fullgraph=True) + reference_srelu = torch.compile(reference_srelu, fullgraph=True) # this only works with fullgraph=True, errors in eager # TODO figure out exactly why this happens @@ -134,19 +135,22 @@ def reference_srelu(x): SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), ) # (reference_linear_copy) - # reference_linear_copy.forward = torch.compile( - # reference_linear_copy.forward, fullgraph=True - # ) + reference_linear_copy.forward = torch.compile( + reference_linear_copy.forward, fullgraph=True + ) reference_output = 
reference_srelu(input_tensor) custom_output = reference_linear_copy(input_tensor) + print(reference_output) + print(custom_output) + torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) from torchao.sparsity.sparse_api import ActivationSparseLinearConfig @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_asdf(M=512, K=2048, N=1024): +def test_asdf(M=16384, K=2048, N=1024): with torch.no_grad(): torch.manual_seed(0) input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() @@ -159,26 +163,29 @@ def test_asdf(M=512, K=2048, N=1024): quantize_( reference_linear, Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=False) + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), ) + reference_linear.forward = torch.compile(reference_linear.forward, fullgraph=True) # this only works with fullgraph=True, errors in eager # TODO figure out exactly why this happens sparsify_( reference_linear_copy, - ActivationSparseLinearConfig(), + Float8DynamicSemiSparseActivationFloat8WeightConfig( + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + ), ) # (reference_linear_copy) - # reference_linear_copy.forward = torch.compile( - # reference_linear_copy.forward, fullgraph=True - # ) + reference_linear_copy.forward = torch.compile( + reference_linear_copy.forward, fullgraph=True, + ) reference_output = reference_linear(input_tensor) custom_output = reference_linear_copy(input_tensor) - print(reference_output) - print(custom_output) + print(reference_output.is_contiguous()) + print(custom_output.is_contiguous()) torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index 9e978edc16..7bd49696c5 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -169,7 +169,7 @@ struct SparseRowwiseKernel { cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, TileShape, - cute::Shape, + cute::Shape, cutlass::epilogue::collective::EpilogueTileAuto, float, float, @@ -193,8 +193,8 @@ struct SparseRowwiseKernel { cutlass::layout::ColumnMajor, 16, float, - cute::Shape, - cute::Shape, + cute::Shape, + cute::Shape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; @@ -272,6 +272,11 @@ Tensor _sparse24_fp8_sm90_cutlass_gemm( {cute::get<0>(args.problem_shape), cute::get<1>(args.problem_shape)}, at::TensorOptions().dtype(K::kElementOutAt)); + // meta registration + if (kIsMeta) { + return out; + } + args.mainloop.ptr_A = reinterpret_cast(tensor_a.data_ptr()); args.mainloop.ptr_B = static_cast(tensor_b.data_ptr()); @@ -343,9 +348,9 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); } -// TORCH_LIBRARY_IMPL(torchao, Meta, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), -// TORCH_FN(torchao::_sparse24_fp8_sm90_cutlass_gemm)); -// } +TORCH_LIBRARY_IMPL(torchao, Meta, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), + TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); +} #endif diff --git a/torchao/csrc/cuda/activation24/sparsify24.cu b/torchao/csrc/cuda/activation24/sparsify24.cu index b41bf576ca..e8949fa5d8 100644 --- 
a/torchao/csrc/cuda/activation24/sparsify24.cu +++ b/torchao/csrc/cuda/activation24/sparsify24.cu @@ -412,8 +412,8 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { TORCH_FN(sparse24_sm90_sparsify)); } -// TORCH_LIBRARY_IMPL(torchao, Meta, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"), -// TORCH_FN(sparse24_sm90_sparsify)); -// } +TORCH_LIBRARY_IMPL(torchao, Meta, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchao::sparse24_sm90_sparsify"), + TORCH_FN(sparse24_sm90_sparsify)); +} diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 6cb2e8997e..748fef7ceb 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -458,10 +458,25 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): + from torchao.dtypes.floatx.cutlass_semi_sparse_layout import CutlassSemiSparseLayout """Convert a high precision tensor to a float8 quantized tensor.""" if target_dtype in FP8_TYPES: original_shape = input_float.shape input_float = _layout.pre_process(input_float) + + # handle CUTLASS specially + if isinstance(_layout, CutlassSemiSparseLayout): + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(input_float, None, None, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + dtype=input_float.dtype, + ) + + + scale = choose_qparams_affine_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index e9702a33ac..bfafc352d0 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -14,6 +14,8 @@ from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( _linear_fp8_act_fp8_weight_sparse_cutlass_check, _linear_fp8_act_fp8_weight_sparse_cutlass_impl, + _linear_fp8_act_sparse_fp8_weight_cutlass_check, + _linear_fp8_act_sparse_fp8_weight_cutlass_impl, ) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, @@ -191,6 +193,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl, ), + ( + _linear_fp8_act_sparse_fp8_weight_cutlass_check, + _linear_fp8_act_sparse_fp8_weight_cutlass_impl, + ), ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 45fe451712..f5377034b7 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -15,7 +15,7 @@ AffineQuantizedTensor, register_layout, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape from torchao.ops import ( rowwise_scaled_linear_sparse_cutlass_f8f8, to_sparse_semi_structured_cutlass_sm9x_f8, @@ -42,11 +42,11 @@ def _same_metadata( class CutlassSemiSparseLayout(Layout): """Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel.""" - def pre_process(self, dense: torch.Tensor) -> torch.Tensor: - # prune to 2:4 if not already - from torchao.sparsity.utils import mask_creator + # def pre_process(self, dense: torch.Tensor) -> torch.Tensor: + # # prune to 2:4 if not already + # from torchao.sparsity.utils 
import mask_creator - return dense * mask_creator(dense).bool() + # return dense * mask_creator(dense).bool() @register_layout(CutlassSemiSparseLayout) @@ -54,6 +54,7 @@ class CutlassSemiSparseTensorImpl(AQTTensorImpl): @staticmethod def __new__( cls, + shape: torch.Size, sparse: torch.Tensor, meta: torch.Tensor, scale: torch.Tensor, @@ -66,11 +67,12 @@ def __new__( ) kwargs["dtype"] = sparse.dtype kwargs["requires_grad"] = False - shape = (sparse.shape[0], 2 * sparse.shape[-1]) + # shape = (sparse.shape[0], 2 * sparse.shape[-1]) return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, + shape: torch.Size, sparse: torch.Tensor, meta: torch.Tensor, scale: torch.Tensor, @@ -80,6 +82,7 @@ def __init__( self.meta = meta self.scale = scale self._layout = _layout + self._shape = shape @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -106,7 +109,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["sparse", "meta", "scale"], [self._layout] + return ["sparse", "meta", "scale"], [self._layout, self._shape] @classmethod def __tensor_unflatten__( @@ -115,35 +118,37 @@ def __tensor_unflatten__( sparse = tensor_data_dict["sparse"] meta = tensor_data_dict["meta"] scale = tensor_data_dict["scale"] - (_layout,) = tensor_attributes - return cls(sparse, meta, scale, _layout) + (_layout, _shape) = tensor_attributes + return cls(_shape, sparse, meta, scale, _layout) def get_plain(self): # No support in CUTLASS to convert back to dense from sparse # semi-structured format, so multiplying with identity matrix, # and using identity scale factors, for the conversion. - cols = self.shape[1] - input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) - input_scale = torch.ones( - (cols,), dtype=self.scale.dtype, device=self.sparse.device - ) - sparse_scale = torch.ones_like(self.scale) - out_dtype = torch.bfloat16 - dense = ( - rowwise_scaled_linear_sparse_cutlass_f8f8( - input, - input_scale, - self.sparse, - self.meta, - sparse_scale, - out_dtype=out_dtype, - ) - .to(self.dtype) - .t() - .contiguous() - ) - - return dense, self.scale, None + # breakpoint() + raise NotImplementedError("get_plain not supported for CutlassSemiSparseTensorImpl") + # cols = self.shape[-1] + # input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) + # input_scale = torch.ones( + # (cols,), dtype=self.scale.dtype, device=self.sparse.device + # ) + # sparse_scale = torch.ones_like(self.scale) + # out_dtype = torch.bfloat16 + # dense = ( + # rowwise_scaled_linear_sparse_cutlass_f8f8( + # input, + # input_scale, + # self.sparse, + # self.meta, + # sparse_scale, + # out_dtype=out_dtype, + # ) + # .to(self.dtype) + # .t() + # .contiguous() + # ) + + # return dense, self.scale, None @classmethod def from_plain( @@ -154,13 +159,24 @@ def from_plain( _layout: Layout, ): assert zero_point is None or torch.all(zero_point == 0) - - sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense) + # print(dense.shape) + dense_2d = dense.view(-1, dense.shape[-1]) + + X_scale = torch.empty((dense_2d.shape[0], 1), device=dense.device, dtype=torch.float32) + Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( + dense_2d, + "cutlass", + "identity", + "largest", + dtype=torch.float8_e4m3fn, + scale=X_scale, + ) return cls( - sparse, - meta, - scale, + dense.shape, + Xq_sparse, + X_meta, + X_scale, _layout, ) @@ -210,3 +226,50 @@ def 
_linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, ) return out + +def _linear_fp8_act_sparse_fp8_weight_cutlass_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import Float8Layout + + # if isinstance(input_tensor, AffineQuantizedTensor) and isinstance(input_tensor._layout, CutlassSemiSparseLayout): + # breakpoint() + + res = ( + isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, CutlassSemiSparseLayout) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == 2 + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, Float8Layout) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == torch.float32 + and len(weight_tensor.tensor_impl.scale.shape) == 2 + and (bias is None or bias.dtype == input_tensor.dtype) + and (bias is None or len(bias.shape) == 1) + ) + return res + +def _linear_fp8_act_sparse_fp8_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 + + input_sparse = input_tensor.tensor_impl.sparse + input_meta = input_tensor.tensor_impl.meta + input_scale = input_tensor.tensor_impl.scale + weight = weight_tensor.tensor_impl.float8_data + weight_scale = weight_tensor.tensor_impl.scale + + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + out_dtype = input_tensor.dtype + + # out = rowwise_scaled_linear_sparse_cutlass_f8f8( + # weight, weight_scale, input_sparse, input_meta, input_scale, bias, out_dtype + # ).t().view(out_shape) + + out= torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + input_sparse, input_meta, weight.t(), a_scale=input_scale, b_scale=weight_scale.t(), + ).view(out_shape) + + return out diff --git a/torchao/ops.py b/torchao/ops.py index ab262a2eff..a44779f2ec 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -843,18 +843,18 @@ def sparse24_sm90_sparsify( ) -@register_custom_op("torchao::sparse24_fp8_sm90_cutlass_gemm") -def _( - a: Tensor, - meta: Tensor, - b: Tensor, - a_scale: Optional[Tensor], - b_scale: Optional[Tensor], - swizzle_size: int = 8, - swizzle_axis: str = 'n', - sm_count: int = 128, -) -> Tensor: - return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device) +# @register_custom_op("torchao::sparse24_fp8_sm90_cutlass_gemm") +# def _( +# a: Tensor, +# meta: Tensor, +# b: Tensor, +# a_scale: Optional[Tensor], +# b_scale: Optional[Tensor], +# swizzle_size: int = 8, +# swizzle_axis: str = 'n', +# sm_count: int = 128, +# ) -> Tensor: +# return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device) def swizzle_mm( diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index e126a10e5f..3de8edb15b 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -38,12 +38,6 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform( ): return FP8SemiSparseActivationLinear.from_dense(module, config) -def _to_fp8_rowwise(x: torch.Tensor, dtype): - max_v = torch.finfo(dtype).max - x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() - x = (x / x_scale).to(dtype) - return x, x_scale - class FP8SemiSparseActivationLinear(nn.Module): """ @@ 
-54,44 +48,42 @@ def __init__(self, weight, config) -> None: super().__init__() self.config = config - # W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype) - # self.Wq = W_aqt.tensor_impl.float8_data - # self.W_scale = W_aqt.tensor_impl.scale - W, W_scale = _to_fp8_rowwise(weight, self.config.weight_dtype) - self.W = W - self.W_scale = W_scale + W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype) + self.W = W_aqt.tensor_impl.float8_data + self.W_scale = W_aqt.tensor_impl.scale def forward(self, x): + # breakpoint() + # print(x) X_scale = torch.empty([x.shape[0], 1], device=x.device, dtype=torch.float32) - # X_scale = _float8_cutlass_quant(x, self.config.activation_dtype).tensor_impl.scale.repeat([x.shape[0], 1]) Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( x, "cutlass", - "srelu", + "identity", "largest", dtype=self.config.activation_dtype, scale=X_scale, ) - breakpoint() - result = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( - Xq_sparse, - X_meta, - self.W.T, - a_scale=X_scale, - b_scale=self.W_scale.T, - ) - # result = rowwise_scaled_linear_sparse_cutlass_f8f8( - # self.Wq, - # self.W_scale, + # self.W, + # self.W_scale.squeeze(), # Xq_sparse, # X_meta, - # X_scale, + # X_scale.squeeze(), # bias=None, # out_dtype=torch.bfloat16, # ).t() + # result = + result = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + Xq_sparse, + X_meta, + self.W.t(), + a_scale=X_scale, + b_scale=self.W_scale.t(), + ) + return result @classmethod diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index cee8df21a2..446f512c65 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -674,6 +674,8 @@ def _dequantize_affine_no_dtype_check( 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain 3. 
reshape the quantized result to origianl shape and change dtype to the output_dtype """ + if len(block_size) != input.dim(): + breakpoint() assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" ) diff --git a/torchao/sparsity/activation/squared_relu_sparse.py b/torchao/sparsity/activation/squared_relu_sparse.py index 7d40aeb97d..0714d89ca7 100644 --- a/torchao/sparsity/activation/squared_relu_sparse.py +++ b/torchao/sparsity/activation/squared_relu_sparse.py @@ -1,44 +1,39 @@ - import types from dataclasses import dataclass -from typing import Callable, Optional +from typing import List, Optional, Union import torch -from torch.sparse import to_sparse_semi_structured +import torchao from torchao.core.config import AOBaseConfig -from torchao.float8.inference import Float8MMConfig -from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( - WeightNormSparsifier, +from torchao.dtypes import ( + CutlassSemiSparseLayout, + Float8Layout, + to_affine_quantized_floatx, ) -from torchao.quantization.quant_api import ( - _is_linear, - _linear_extra_repr, - _replace_with_custom_fn_if_matches_filter, -) -from torchao.quantization.transform_module import ( - _QUANTIZE_CONFIG_HANDLER, - register_quantize_module_handler, -) -from torchao.sparsity.blocksparse import BlockSparseTensor -from dataclasses import dataclass - -import torch -from torch import nn - -from torchao.core.config import AOBaseConfig -from torchao.ops import ( - rowwise_scaled_linear_sparse_cutlass_f8f8, +from torchao.float8.config import e4m3_dtype +from torchao.float8.inference import ( + Float8MMConfig, + FP8Granularity, + _check_hardware_support, + _normalize_granularity, ) +from torchao.quantization.observer import get_block_size from torchao.quantization.quant_api import ( + PerRow, _float8_cutlass_quant, + _linear_extra_repr, + to_linear_activation_quantized, ) from torchao.quantization.transform_module import ( register_quantize_module_handler, ) +from torchao.utils import ( + is_MI300, + is_sm_at_least_89, +) + -from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv -from torch.utils._python_dispatch import return_and_correct_aliasing def _to_fp8_rowwise(x: torch.Tensor, dtype): max_v = torch.finfo(dtype).max x_scale = (x.abs().max(1, keepdim=True)[0].clip(1e-12) / max_v).float() @@ -47,20 +42,12 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype): from torchao.utils import TorchAOBaseTensor -from torchao.quantization import LinearActivationQuantizedTensor -from torchao.float8.float8_utils import tensor_to_scale -from torchao.float8.float8_scaling_utils import ( - get_maybe_axiswise_dim, - hp_tensor_to_float8_dynamic, - hp_tensor_and_scale_to_float8, -) -from torchao.float8.config import CastConfig, ScalingGranularity @dataclass class ActivationSparseLinearConfig(AOBaseConfig): """ - Adds in acceleration for activation sparsity to linear layers for decode. + Adds in acceleration for activation sparsity to linear layers for decode. Args: `activation_dtype`: data type for quantized activation tensor. 
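A minimal usage sketch for the config above, mirroring how the tests in this patch apply it via `sparsify_`; this is illustrative only (layer shapes and the wrapping `nn.Sequential` are assumptions, not part of the diff):

    import torch
    from torchao.sparsity.sparse_api import ActivationSparseLinearConfig, sparsify_

    # bs=1 decode-style linear: the transform swaps the Linear weight for an
    # ActivationSparseTensor (fp8 weight data) so the activation-sparse decode
    # path can be used at runtime
    model = torch.nn.Sequential(
        torch.nn.Linear(2048, 1024, bias=False)
    ).cuda().to(torch.bfloat16)
    sparsify_(model, ActivationSparseLinearConfig())
    out = model(torch.randn(1, 2048, device="cuda", dtype=torch.bfloat16))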
@@ -70,8 +57,10 @@ class ActivationSparseLinearConfig(AOBaseConfig): activation_dtype: torch.dtype = torch.float8_e4m3fn weight_dtype: torch.dtype = torch.float8_e4m3fn -@register_quantize_module_handler( - ActivationSparseLinearConfig) + mm_config = Float8MMConfig(use_fast_accum=True) + + +@register_quantize_module_handler(ActivationSparseLinearConfig) def _( module: torch.nn.Module, config: ActivationSparseLinearConfig, @@ -138,27 +127,17 @@ def __tensor_unflatten__( @classmethod def from_dense(cls, weight, use_fp8=True): if use_fp8: - # weight, scale = _to_fp8_rowwise(weight, torch.float8_e4m3fn) - # scale = None - scale = tensor_to_scale( - weight, - torch.float8_e4m3fn, - reduce_amax=False, - device_mesh=None, - scaling_granularity=ScalingGranularity.TENSORWISE, - axiswise_dim=-1, - round_scales_to_power_of_2=False, - ) - x2_lp = hp_tensor_and_scale_to_float8(weight, scale, torch.float8_e4m3fn) - return cls(weight.shape, - data=x2_lp, - scale=None, - requires_grad=False) + W_aqt = _float8_cutlass_quant(weight, torch.float8_e4m3fn) + W = W_aqt.tensor_impl.float8_data + W_scale = W_aqt.tensor_impl.scale + return cls(weight.shape, data=W, scale=W_scale, requires_grad=False) else: - return cls(weight.shape, - data=weight.data.t().contiguous().t(), - scale=None, - requires_grad=False) + return cls( + weight.shape, + data=weight.data.t().contiguous().t(), + scale=None, + requires_grad=False, + ) def apply_fn_to_shard(self, func): return ActivationSparseTensor( @@ -168,6 +147,7 @@ def apply_fn_to_shard(self, func): requires_grad=self.requires_grad, ) + # Subclass op dispatch registration implements = ActivationSparseTensor.implements aten = torch.ops.aten @@ -192,62 +172,229 @@ def _(func, types, args, kwargs): requires_grad=False, ) -@implements( - [aten.copy_.default] -) + +@implements([aten.copy_.default]) def _(func, types, args, kwargs): self = args[0] src = args[1] if not isinstance(src, ActivationSparseTensor): src_subclass = ActivationSparseTensor.from_dense(src) + self.data.copy_(src_subclass.data) + self.scale.copy_(src_subclass.scale) + return + - self.data.copy_(src.data) - # slef.scale.copy_(src.scale) - if self.scale is None: - self.scale = None +def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. + If padding is not required, this function returns the original tensor. 
+ """ + # only 2d matmul + assert dense_input.dim() == 2 + + # check shape + m, n = dense_input.shape + min_rows = 64 + min_cols = 64 + + # calculate padding + to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 + to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 + if to_pad_m or to_pad_n: + return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) else: - self.scale.copy_(src.scale) - return + return dense_input + @implements(torch.nn.functional.linear) def sparse_activation_linear(func, types, args, kwargs): x_orig, w, bias = args assert bias is None x = x_orig.view(-1, x_orig.size(-1)) - # M = w.shape[0] - # K = w.shape[1] - - if x.shape[0] % 64 != 0: - # w_dequantized = (w.data.to(torch.bfloat16)) - # x_relu = torch.square(torch.nn.functional.relu(x)) - return torch.nn.functional.linear(x_orig, w.data.to_original_precision().to(torch.bfloat16), bias) - # res = torch.ops.torchao.splitk_sparse_gemv(x_relu, - # w.data) - # return res.view(*x_orig.shape[:-1], w.shape[0]) + m, n = x.shape + + # # # if x input is the right shape, we use sparse matmul + # x_padded = _pad_dense_input(x) + # if (x.size(0) % 64) == 0: + # if (x.size(0) == 64) or (x.size(0) == 128) or (x.size(0) ==256) or (x.size(0)==512): + if False: + X_scale = torch.empty( + [x.shape[0], 1], dtype=torch.float32, device=x_orig.device + ) + Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( + x, + "cutlass", + "identity", + "largest", + dtype=torch.float8_e4m3fn, + scale=X_scale, + ) + + out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + Xq_sparse, + X_meta, + w.data.t(), + a_scale=X_scale, + b_scale=w.scale.t(), + ) + # print(out_sparse.shape) + out_sparse = out_sparse.reshape(*x_orig.shape[:-1], w.shape[0]) + return out_sparse else: - # X_scale = torch.empty([x.shape[0], 1], dtype=torch.float32, device=x.device) - # Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( - # x, - # "cutlass", - # "identity", - # "largest", - # dtype=torch.float8_e4m3fn, - # scale=X_scale, - # ) - x_ast = ActivationSparseTensor.from_dense(x_orig) # .data.to_original_precision().to(torch.bfloat16) - # x_orig = - # return torch.nn.functional.linear(x_ast, w.data.to_original_precision().to(torch.bfloat16), bias) - breakpoint() - return torch._scaled_mm(x_ast.data._data, w.data._data.T, scale_a=x_ast.data._scale, scale_b=w.data._scale.T, out_dtype=torch.bfloat16) + w_dequantized = (w.data.to(torch.float32) * w.scale).to(torch.bfloat16) + return torch.nn.functional.linear(x_orig, w_dequantized, bias) - out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( - Xq_sparse, X_meta, w.data._data.T, a_scale=X_scale, b_scale=w.data._scale.T, +from torchao.quantization.quant_api import ( + Float8Layout, + _check_hardware_support, + _fp8_mm_compat, + to_affine_quantized_floatx, +) + + +@dataclass +class Float8DynamicSemiSparseActivationFloat8WeightConfig(AOBaseConfig): + """ + Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers. + + Args: + activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. + weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + granularity: + The granularity for quantization. Can be either a single granularity (applied to both + activations and weights) or a tuple of two granularities (one for activations, one for weights). + If None, defaults to PerTensor for both. 
Currently both quantizations need to be the same type. And + only PerTensor and PerRow are supported. + mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + + """ + + activation_dtype: torch.dtype = e4m3_dtype + weight_dtype: torch.dtype = e4m3_dtype + granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None + mm_config: Optional[Float8MMConfig] = None + set_inductor_config: bool = True + + def __post_init__(self): + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + + activation_granularity, weight_granularity = _normalize_granularity( + self.granularity + ) + self.granularity = [activation_granularity, weight_granularity] + + +def _float8_dynamic_sparse_activation_float8_weight_quantize_tensor(weight, config): + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + + # Ensure works on device + _check_hardware_support(granularity) + activation_granularity, weight_granularity = granularity + + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return weight + if isinstance(weight_granularity, PerRow): + assert weight.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input weight" ) - out_sparse = out_sparse.view(*x_orig.shape[:-1], w.shape[0]) + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1] + list(block_size)) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + input_quant_func = _input_activation_quant_func_fp8_sparse + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } + + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + return quantized_weight - return out_sparse - # For normal linear - # x_orig_relu = torch.square(torch.nn.functional.relu(x_orig)) - # return torch.nn.functional.linear(x_orig_relu, w.data, bias) +@register_quantize_module_handler(Float8DynamicSemiSparseActivationFloat8WeightConfig) +def _float8_dynamic_activation_sparse_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicSemiSparseActivationFloat8WeightConfig +): + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying float8 dynamic activation quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + quantized_weight = _float8_dynamic_sparse_activation_float8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +def _input_activation_quant_func_fp8_sparse( + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: Optional[torch.Tensor] = None, + zero_point: 
Optional[torch.Tensor] = None, +): + """This function is used to quantize the input activation tensor for an aqt_float variant. If scale + is not provided it will be dynamically calculate the scales otherwise it will use the provided scale. + """ + x_2d = x.view(-1, x.size(-1)) + + assert zero_point is None, ( + "Zero point is not supported for dynamic FP8 quantization" + ) + if isinstance(activation_granularity, PerRow): + assert x.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input activation" + ) + + if ( + (x_2d.size(0) == 64) or + (x_2d.size(0) == 128) or + (x_2d.size(0) == 192) or + (x_2d.size(0) == 256) or + (x_2d.size(0) == 320) or + (x_2d.size(0) == 384) or + (x_2d.size(0) == 448) or + (x_2d.size(0) == 512) or + (x_2d.size(0) == 1024) or + (x_2d.size(0) == 2048) or + (x_2d.size(0) == 4096) or + (x_2d.size(0) == 8192) + ): + layout=CutlassSemiSparseLayout() + else: + layout=Float8Layout(mm_config=None) + + block_size = get_block_size(x.shape, activation_granularity) + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=layout, + ) + return activation diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index c84333a597..0fcf04fdac 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -157,5 +157,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: from torchao.sparsity.activation.squared_relu_sparse import ( - ActivationSparseLinearConfig + ActivationSparseLinearConfig, + Float8DynamicSemiSparseActivationFloat8WeightConfig, ) From fe8d22ef0fc55e799c3c71c6b489464fbd730953 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 18:47:42 -0700 Subject: [PATCH 17/25] eager working --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 8 +- test/sparsity/test_activation24.py | 12 +- torchao/csrc/cuda/activation24/sparse_gemm.cu | 6 +- torchao/dtypes/affine_quantized_tensor.py | 5 +- .../floatx/cutlass_semi_sparse_layout.py | 122 ++++++++++++------ .../activation/squared_relu_sparse.py | 75 +++++------ 6 files changed, 128 insertions(+), 100 deletions(-) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index f74972f998..a9f643b5e3 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -102,7 +102,7 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): nn.Sequential( nn.Linear(hidden_size, intermediate_size, bias=False), # no Squared RELU since it will be fused into the second linear - SquaredReLU(), + # SquaredReLU(), nn.Linear(intermediate_size, hidden_size, bias=False), ) .to(torch.bfloat16) @@ -115,11 +115,11 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): ), ) quantize_( - ffn_clone, + ffn_clone[1], Float8DynamicSemiSparseActivationFloat8WeightConfig( granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), - filter_fn=lambda mod, fqn: "2" in fqn, + # filter_fn=lambda mod, fqn: "1" in fqn, ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) @@ -145,7 +145,7 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): with torch.no_grad(): results = [] # for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): - for num_tokens in tqdm([512, 1024, 2048, 4096, 8192]): + for num_tokens in tqdm([64, 
128, 256, 512, 1024, 2048, 4096, 8192, 16384]): results.append(benchmark(num_tokens)) torch.compiler.reset() diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 05dd4edec5..f4e12d8ebc 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -150,7 +150,7 @@ def reference_srelu(x): from torchao.sparsity.sparse_api import ActivationSparseLinearConfig @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_asdf(M=16384, K=2048, N=1024): +def test_asdf(M=1, K=16384, N=4096): with torch.no_grad(): torch.manual_seed(0) input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() @@ -166,20 +166,20 @@ def test_asdf(M=16384, K=2048, N=1024): granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), ) - reference_linear.forward = torch.compile(reference_linear.forward, fullgraph=True) + # reference_linear.forward = torch.compile(reference_linear.forward) # this only works with fullgraph=True, errors in eager # TODO figure out exactly why this happens - sparsify_( + quantize_( reference_linear_copy, Float8DynamicSemiSparseActivationFloat8WeightConfig( granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), ) # (reference_linear_copy) - reference_linear_copy.forward = torch.compile( - reference_linear_copy.forward, fullgraph=True, - ) + # reference_linear_copy.forward = torch.compile( + # reference_linear_copy.forward, + # ) reference_output = reference_linear(input_tensor) custom_output = reference_linear_copy(input_tensor) diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index 7bd49696c5..7ed7702c7e 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -169,7 +169,7 @@ struct SparseRowwiseKernel { cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, TileShape, - cute::Shape, + cute::Shape, cutlass::epilogue::collective::EpilogueTileAuto, float, float, @@ -193,8 +193,8 @@ struct SparseRowwiseKernel { cutlass::layout::ColumnMajor, 16, float, - cute::Shape, - cute::Shape, + cute::Shape, + cute::Shape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 748fef7ceb..3b2569033f 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -466,8 +466,11 @@ def from_hp_to_floatx( # handle CUTLASS specially if isinstance(_layout, CutlassSemiSparseLayout): + scale = choose_qparams_affine_float8( + input_float, float8_dtype=target_dtype, block_size=block_size + ) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(input_float, None, None, _layout) + tensor_impl = tensor_impl_ctr(input_float, scale, None, _layout) return cls( tensor_impl, block_size, diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index f5377034b7..3132c6774c 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -23,6 +23,48 @@ aten = torch.ops.aten +def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. 
+ If padding is not required, this function returns the original tensor. + """ + # only 2d matmul + assert dense_input.dim() == 2 + + # check shape + m, n = dense_input.shape + min_rows = 64 + min_cols = 64 + + # calculate padding + to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 + to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 + if to_pad_m or to_pad_n: + return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) + else: + return dense_input + +def _pad_scale(scale: torch.Tensor) -> torch.Tensor: + """ + Calculates padding for dense tensor and pads tensor if necessary. + If padding is not required, this function returns the original tensor. + """ + # only 2d matmul + assert scale.dim() == 2 + + # check shape + m, n = scale.shape + assert n == 1 + min_rows = 64 + # min_cols = 64 + + # calculate padding + to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 + # to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 + if to_pad_m: + return torch.nn.functional.pad(scale, (0, 0, 0, to_pad_m)) + else: + return scale def _same_metadata( self: "CutlassSemiSparseTensorImpl", src: "CutlassSemiSparseTensorImpl" @@ -126,29 +168,29 @@ def get_plain(self): # semi-structured format, so multiplying with identity matrix, # and using identity scale factors, for the conversion. # breakpoint() - raise NotImplementedError("get_plain not supported for CutlassSemiSparseTensorImpl") - # cols = self.shape[-1] - # input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) - # input_scale = torch.ones( - # (cols,), dtype=self.scale.dtype, device=self.sparse.device - # ) - # sparse_scale = torch.ones_like(self.scale) - # out_dtype = torch.bfloat16 - # dense = ( - # rowwise_scaled_linear_sparse_cutlass_f8f8( - # input, - # input_scale, - # self.sparse, - # self.meta, - # sparse_scale, - # out_dtype=out_dtype, - # ) - # .to(self.dtype) - # .t() - # .contiguous() - # ) - - # return dense, self.scale, None + # raise NotImplementedError("get_plain not supported for CutlassSemiSparseTensorImpl") + cols = self.shape[-1] + input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) + input_scale = torch.ones( + (cols,), dtype=self.scale.dtype, device=self.sparse.device + ) + sparse_scale = torch.ones_like(self.scale) + out_dtype = torch.bfloat16 + dense = ( + rowwise_scaled_linear_sparse_cutlass_f8f8( + input, + input_scale, + self.sparse, + self.meta, + sparse_scale, + out_dtype=out_dtype, + ) + .to(self.dtype) + .t() + .contiguous() + ) + + return dense, self.scale, None @classmethod def from_plain( @@ -160,25 +202,31 @@ def from_plain( ): assert zero_point is None or torch.all(zero_point == 0) # print(dense.shape) - dense_2d = dense.view(-1, dense.shape[-1]) + # dense_2d = dense.view(-1, dense.shape[-1]) + assert dense.ndim == 2 + assert dense.is_contiguous() + + dense_padded = _pad_dense_input(dense) + scale_padded = _pad_scale(scale) - X_scale = torch.empty((dense_2d.shape[0], 1), device=dense.device, dtype=torch.float32) + # X_scale = torch.empty((dense.shape[0], 1), device=dense.device, dtype=torch.float32) Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( - dense_2d, + dense_padded, "cutlass", "identity", "largest", dtype=torch.float8_e4m3fn, - scale=X_scale, + scale=scale_padded, ) - return cls( + res = cls( dense.shape, Xq_sparse, X_meta, - X_scale, + scale_padded, _layout, ) + return res def get_layout(self) -> Layout: return self._layout @@ -230,9 +278,6 @@ def 
_linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, def _linear_fp8_act_sparse_fp8_weight_cutlass_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import Float8Layout - # if isinstance(input_tensor, AffineQuantizedTensor) and isinstance(input_tensor._layout, CutlassSemiSparseLayout): - # breakpoint() - res = ( isinstance(input_tensor, AffineQuantizedTensor) and isinstance(input_tensor._layout, CutlassSemiSparseLayout) @@ -261,15 +306,10 @@ def _linear_fp8_act_sparse_fp8_weight_cutlass_impl(input_tensor, weight_tensor, weight_scale = weight_tensor.tensor_impl.scale out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + rows, cols = (input_tensor.shape) - out_dtype = input_tensor.dtype - - # out = rowwise_scaled_linear_sparse_cutlass_f8f8( - # weight, weight_scale, input_sparse, input_meta, input_scale, bias, out_dtype - # ).t().view(out_shape) - - out= torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + out = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( input_sparse, input_meta, weight.t(), a_scale=input_scale, b_scale=weight_scale.t(), - ).view(out_shape) - + )[:rows, :].view(out_shape) + return out diff --git a/torchao/sparsity/activation/squared_relu_sparse.py b/torchao/sparsity/activation/squared_relu_sparse.py index 0714d89ca7..7424a75514 100644 --- a/torchao/sparsity/activation/squared_relu_sparse.py +++ b/torchao/sparsity/activation/squared_relu_sparse.py @@ -33,13 +33,7 @@ is_sm_at_least_89, ) - -def _to_fp8_rowwise(x: torch.Tensor, dtype): - max_v = torch.finfo(dtype).max - x_scale = (x.abs().max(1, keepdim=True)[0].clip(1e-12) / max_v).float() - x = (x.float() / x_scale).clamp(min=-max_v, max=max_v).to(dtype) - return x, x_scale - +import torch.nn.functional as F from torchao.utils import TorchAOBaseTensor @@ -184,26 +178,6 @@ def _(func, types, args, kwargs): return -def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor: - """ - Calculates padding for dense tensor and pads tensor if necessary. - If padding is not required, this function returns the original tensor. - """ - # only 2d matmul - assert dense_input.dim() == 2 - - # check shape - m, n = dense_input.shape - min_rows = 64 - min_cols = 64 - - # calculate padding - to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 - to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 - if to_pad_m or to_pad_n: - return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) - else: - return dense_input @implements(torch.nn.functional.linear) @@ -316,6 +290,7 @@ def _float8_dynamic_sparse_activation_float8_weight_quantize_tensor(weight, conf _layout=Float8Layout(mm_config=mm_config), ) + # input_quant_func = torch.compile(_input_activation_quant_func_fp8_sparse, fullgraph=True) input_quant_func = _input_activation_quant_func_fp8_sparse input_quant_kwargs = { "activation_granularity": activation_granularity, @@ -350,7 +325,10 @@ def _float8_dynamic_activation_sparse_float8_weight_transform( module.extra_repr = types.MethodType(_linear_extra_repr, module) return module +from collections import Counter +from pprint import pprint +SEEN = Counter() def _input_activation_quant_func_fp8_sparse( x: torch.Tensor, activation_granularity, @@ -361,7 +339,8 @@ def _input_activation_quant_func_fp8_sparse( """This function is used to quantize the input activation tensor for an aqt_float variant. If scale is not provided it will be dynamically calculate the scales otherwise it will use the provided scale. 
""" - x_2d = x.view(-1, x.size(-1)) + # print(x.shape) + # x_2d = x.view(-1, x.size(-1)) assert zero_point is None, ( "Zero point is not supported for dynamic FP8 quantization" @@ -371,23 +350,29 @@ def _input_activation_quant_func_fp8_sparse( "PerRow quantization only works for bfloat16 precision input activation" ) - if ( - (x_2d.size(0) == 64) or - (x_2d.size(0) == 128) or - (x_2d.size(0) == 192) or - (x_2d.size(0) == 256) or - (x_2d.size(0) == 320) or - (x_2d.size(0) == 384) or - (x_2d.size(0) == 448) or - (x_2d.size(0) == 512) or - (x_2d.size(0) == 1024) or - (x_2d.size(0) == 2048) or - (x_2d.size(0) == 4096) or - (x_2d.size(0) == 8192) - ): - layout=CutlassSemiSparseLayout() - else: - layout=Float8Layout(mm_config=None) + # x_2d = _pad_dense_input(x_2d) + # if x.shape not in SEEN: + # SEEN[x.shape] += 1 + # pprint(SEEN) + # else: + # SEEN[x.shape] += 1 + + + # if ( + # (x.size(0) == 64) or + # (x.size(0) == 128) or + # (x.size(0) == 192) or + # (x.size(0) == 256) or + # (x.size(0) == 320) or + # (x.size(0) == 384) or + # (x.size(0) == 448) or + # (x.size(0) == 512) + # ): + # print(x.shape) + # if x.shape[0] % 64 == 0: + # else: + # layout=Float8Layout(mm_config=None) + layout=CutlassSemiSparseLayout() block_size = get_block_size(x.shape, activation_granularity) activation = to_affine_quantized_floatx( From 8180ee7ddb04bf5141d2a5a8da2396c45981a37f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 23:25:57 -0700 Subject: [PATCH 18/25] cleanup --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 12 +- test/sparsity/test_activation24.py | 123 ++---- torchao/dtypes/affine_quantized_tensor.py | 7 +- torchao/dtypes/affine_quantized_tensor_ops.py | 4 +- .../floatx/cutlass_semi_sparse_layout.py | 48 ++- .../sparsity/activation/srelu_linear.py | 29 +- .../sparsity/activation/float8dynamic_24.py | 169 ++++++++ .../activation/squared_relu_sparse.py | 385 ------------------ torchao/sparsity/sparse_api.py | 29 +- 9 files changed, 245 insertions(+), 561 deletions(-) create mode 100644 torchao/sparsity/activation/float8dynamic_24.py delete mode 100644 torchao/sparsity/activation/squared_relu_sparse.py diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index a9f643b5e3..3c9ac7a70c 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -9,20 +9,16 @@ from tqdm import tqdm from triton.testing import do_bench -from torchao.prototype.sparsity.activation.srelu_linear import ( - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, -) -from torchao.sparsity.sparse_api import ( - Float8DynamicSemiSparseActivationFloat8WeightConfig -) from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( - Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8MMConfig, PerRow, quantize_, ) +from torchao.sparsity.sparse_api import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig, +) PROFILE = False @@ -94,6 +90,7 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): if PROFILE: print("PROFILING FP8") from torchao.prototype.sparsity.activation.utils import profiler_runner + inputs = (ffn_clone, input_tensor) profiler_runner(None, benchmark_microseconds, *inputs) @@ -127,6 +124,7 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): if PROFILE: print("PROFILING 24") from torchao.prototype.sparsity.activation.utils import profiler_runner + inputs = 
(ffn_clone, input_tensor) profiler_runner(None, benchmark_microseconds, *inputs) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index f4e12d8ebc..dbab3bf8bb 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -9,17 +9,18 @@ quantize_, ) from torchao.quantization.quant_api import _float8_cutlass_quant -from torchao.sparsity.activation.squared_relu_sparse import Float8DynamicSemiSparseActivationFloat8WeightConfig +from torchao.sparsity.sparse_api import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig, +) torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True import copy import unittest -from torchao.prototype.sparsity.activation.srelu_linear import ( - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig, -) -from torchao.sparsity import sparsify_ +from parameterized import parameterized + +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor from torchao.utils import is_sm_at_least_90 @@ -103,8 +104,18 @@ def test_sparse24_sm90_sparsify_srelu(M=512, K=1024, fp8=torch.float8_e4m3fn) -> assert (A_packed != A_packed_ref).float().mean().item() < 0.1 +@parameterized.expand( + [ + (1, 8192, 1024, True), + (64, 8192, 1024, True), + (1024, 8192, 1024, True), + (1, 8192, 1024, False), + (64, 8192, 1024, False), + (1024, 8192, 1024, False), + ] +) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): +def test_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False): with torch.no_grad(): torch.manual_seed(0) input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() @@ -121,72 +132,27 @@ def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024): ), ) - # define reference implementation - def reference_srelu(x): - x = F.relu(x) ** 2 - return reference_linear(x) - - reference_srelu = torch.compile(reference_srelu, fullgraph=True) - - # this only works with fullgraph=True, errors in eager - # TODO figure out exactly why this happens - sparsify_( - reference_linear_copy, - SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(), - ) - # (reference_linear_copy) - reference_linear_copy.forward = torch.compile( - reference_linear_copy.forward, fullgraph=True - ) - - reference_output = reference_srelu(input_tensor) - custom_output = reference_linear_copy(input_tensor) - - print(reference_output) - print(custom_output) - - torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) - - -from torchao.sparsity.sparse_api import ActivationSparseLinearConfig -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_asdf(M=1, K=16384, N=4096): - with torch.no_grad(): - torch.manual_seed(0) - input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() - # we have to wrap in a sequential block for quantize_ to work properly - reference_linear = torch.nn.Sequential( - torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16) - ) - reference_linear_copy = copy.deepcopy(reference_linear) - - quantize_( - reference_linear, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) - ), - ) - # reference_linear.forward = torch.compile(reference_linear.forward) + if do_compile: + reference_linear.forward = torch.compile( + reference_linear.forward, + 
fullgraph=True, + ) - # this only works with fullgraph=True, errors in eager - # TODO figure out exactly why this happens quantize_( reference_linear_copy, Float8DynamicSemiSparseActivationFloat8WeightConfig( granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), ) - # (reference_linear_copy) - # reference_linear_copy.forward = torch.compile( - # reference_linear_copy.forward, - # ) + + if do_compile: + reference_linear_copy.forward = torch.compile( + reference_linear_copy.forward, fullgraph=True + ) reference_output = reference_linear(input_tensor) custom_output = reference_linear_copy(input_tensor) - print(reference_output.is_contiguous()) - print(custom_output.is_contiguous()) - torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) @@ -194,13 +160,13 @@ def test_asdf(M=1, K=16384, N=4096): def test_splitk_sparse_gemv(): torch.manual_seed(0) - activation = create_binary_tensor((1, 1, 4096), 0.2).cuda().to(torch.float16) + activation = create_binary_tensor((1, 4096), 0.2).cuda().to(torch.float16) weight = torch.randn(16384, 4096, dtype=torch.float16).cuda() # weight must be column major weight_transposed = weight.T.contiguous().T - sparse_res = torch.ops.torchao.splitk_sparse_gemv(activation, weight_transposed) + sparse_res = splitk_sparse_gemv(activation, weight_transposed) dense_res = F.linear(activation, weight_transposed) # This rtol is ridiculousl high, because the split gemv output accumulates slightly differently than the dense output. @@ -234,7 +200,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye( # Check MM with scale b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32) a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32) - A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale ) assert torch.allclose( @@ -270,32 +236,3 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype): A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype ) assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01) - -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor_compile( - M=512, N=1024, K=256, dtype=torch.float8_e4m3fn -) -> None: - def _to_fp8_rowwise(x: torch.Tensor, dtype): - max_v = torch.finfo(dtype).max - x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() - x = (x / x_scale).to(dtype) - return x, x_scale - - torch.manual_seed(0) - A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() - A, a_scale = _to_fp8_rowwise(A_dense, dtype) - - B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16) - B, b_scale = _to_fp8_rowwise(B_dense, dtype) - - B = B.T - b_scale = b_scale.T - - A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A) - out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( - A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale - ) - out_ref = torch._scaled_mm( - A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype - ) - assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 3b2569033f..00e808b01d 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -458,7 +458,10 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: 
Optional[torch.dtype] = None, ): - from torchao.dtypes.floatx.cutlass_semi_sparse_layout import CutlassSemiSparseLayout + from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( + CutlassSemiSparseLayout, + ) + """Convert a high precision tensor to a float8 quantized tensor.""" if target_dtype in FP8_TYPES: original_shape = input_float.shape @@ -478,8 +481,6 @@ def from_hp_to_floatx( dtype=input_float.dtype, ) - - scale = choose_qparams_affine_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bfafc352d0..14ce697f6e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -14,7 +14,7 @@ from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( _linear_fp8_act_fp8_weight_sparse_cutlass_check, _linear_fp8_act_fp8_weight_sparse_cutlass_impl, - _linear_fp8_act_sparse_fp8_weight_cutlass_check, + _linear_fp8_act_sparse_fp8_weight_cutlass_check, _linear_fp8_act_sparse_fp8_weight_cutlass_impl, ) from torchao.dtypes.floatx.float8_layout import ( @@ -194,7 +194,7 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int8_weight_semi_structured_sparse_impl, ), ( - _linear_fp8_act_sparse_fp8_weight_cutlass_check, + _linear_fp8_act_sparse_fp8_weight_cutlass_check, _linear_fp8_act_sparse_fp8_weight_cutlass_impl, ), ( diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 3132c6774c..2449433078 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -18,11 +18,11 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape from torchao.ops import ( rowwise_scaled_linear_sparse_cutlass_f8f8, - to_sparse_semi_structured_cutlass_sm9x_f8, ) aten = torch.ops.aten + def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor: """ Calculates padding for dense tensor and pads tensor if necessary. @@ -32,18 +32,16 @@ def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor: assert dense_input.dim() == 2 # check shape - m, n = dense_input.shape + m, n = dense_input.size() min_rows = 64 min_cols = 64 # calculate padding - to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 - to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 - if to_pad_m or to_pad_n: - return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) - else: - return dense_input - + to_pad_m = -m % min_rows + to_pad_n = -n % min_cols + return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) + + def _pad_scale(scale: torch.Tensor) -> torch.Tensor: """ Calculates padding for dense tensor and pads tensor if necessary. 
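The simplified `_pad_dense_input` above now always rounds the activation up to the next multiple of 64 rows/columns (a no-op pad when the size is already aligned), which is the tile granularity the CUTLASS sparse GEMM in this patch expects. A small worked example of the arithmetic, for illustration only:

    m, n = 1, 16384     # bs=1 decode activation
    to_pad_m = -m % 64  # 63: a single row is padded up to a full 64-row tile
    to_pad_n = -n % 64  # 0: the feature dimension is already aligned
    # After the padded sparse GEMM runs, the extra rows are dropped again via
    # the `[:rows, :]` slice in _linear_fp8_act_sparse_fp8_weight_cutlass_impl.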
@@ -53,18 +51,14 @@ def _pad_scale(scale: torch.Tensor) -> torch.Tensor: assert scale.dim() == 2 # check shape - m, n = scale.shape + m, n = scale.size() assert n == 1 min_rows = 64 - # min_cols = 64 # calculate padding - to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 - # to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 - if to_pad_m: - return torch.nn.functional.pad(scale, (0, 0, 0, to_pad_m)) - else: - return scale + to_pad_m = -m % min_rows + return torch.nn.functional.pad(scale, (0, 0, 0, to_pad_m)) + def _same_metadata( self: "CutlassSemiSparseTensorImpl", src: "CutlassSemiSparseTensorImpl" @@ -85,10 +79,10 @@ class CutlassSemiSparseLayout(Layout): """Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel.""" # def pre_process(self, dense: torch.Tensor) -> torch.Tensor: - # # prune to 2:4 if not already - # from torchao.sparsity.utils import mask_creator + # # prune to 2:4 if not already + # from torchao.sparsity.utils import mask_creator - # return dense * mask_creator(dense).bool() + # return dense * mask_creator(dense).bool() @register_layout(CutlassSemiSparseLayout) @@ -275,6 +269,7 @@ def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, return out + def _linear_fp8_act_sparse_fp8_weight_cutlass_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import Float8Layout @@ -296,9 +291,8 @@ def _linear_fp8_act_sparse_fp8_weight_cutlass_check(input_tensor, weight_tensor, ) return res -def _linear_fp8_act_sparse_fp8_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 +def _linear_fp8_act_sparse_fp8_weight_cutlass_impl(input_tensor, weight_tensor, bias): input_sparse = input_tensor.tensor_impl.sparse input_meta = input_tensor.tensor_impl.meta input_scale = input_tensor.tensor_impl.scale @@ -306,10 +300,14 @@ def _linear_fp8_act_sparse_fp8_weight_cutlass_impl(input_tensor, weight_tensor, weight_scale = weight_tensor.tensor_impl.scale out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) - rows, cols = (input_tensor.shape) + rows, cols = input_tensor.shape out = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( - input_sparse, input_meta, weight.t(), a_scale=input_scale, b_scale=weight_scale.t(), + input_sparse, + input_meta, + weight.t(), + a_scale=input_scale, + b_scale=weight_scale.t(), )[:rows, :].view(out_shape) - + return out diff --git a/torchao/prototype/sparsity/activation/srelu_linear.py b/torchao/prototype/sparsity/activation/srelu_linear.py index 3de8edb15b..f8c3288b67 100644 --- a/torchao/prototype/sparsity/activation/srelu_linear.py +++ b/torchao/prototype/sparsity/activation/srelu_linear.py @@ -49,40 +49,29 @@ def __init__(self, weight, config) -> None: self.config = config W_aqt = _float8_cutlass_quant(weight, self.config.weight_dtype) - self.W = W_aqt.tensor_impl.float8_data + self.Wq = W_aqt.tensor_impl.float8_data self.W_scale = W_aqt.tensor_impl.scale def forward(self, x): - # breakpoint() - # print(x) X_scale = torch.empty([x.shape[0], 1], device=x.device, dtype=torch.float32) Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( x, "cutlass", - "identity", + "srelu", "largest", dtype=self.config.activation_dtype, scale=X_scale, ) - # result = rowwise_scaled_linear_sparse_cutlass_f8f8( - # self.W, - # self.W_scale.squeeze(), - # Xq_sparse, - # X_meta, - # X_scale.squeeze(), - # bias=None, - # out_dtype=torch.bfloat16, - # ).t() - - # result = - result = 
torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + result = rowwise_scaled_linear_sparse_cutlass_f8f8( + self.Wq, + self.W_scale, Xq_sparse, X_meta, - self.W.t(), - a_scale=X_scale, - b_scale=self.W_scale.t(), - ) + X_scale, + bias=None, + out_dtype=torch.bfloat16, + ).t() return result diff --git a/torchao/sparsity/activation/float8dynamic_24.py b/torchao/sparsity/activation/float8dynamic_24.py new file mode 100644 index 0000000000..e6e1f69f2d --- /dev/null +++ b/torchao/sparsity/activation/float8dynamic_24.py @@ -0,0 +1,169 @@ +import types +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch + +import torchao +from torchao.core.config import AOBaseConfig +from torchao.dtypes import ( + CutlassSemiSparseLayout, + Float8Layout, + to_affine_quantized_floatx, +) +from torchao.float8.config import e4m3_dtype +from torchao.float8.inference import ( + Float8MMConfig, + FP8Granularity, + _check_hardware_support, + _normalize_granularity, +) +from torchao.quantization.observer import get_block_size +from torchao.quantization.quant_api import ( + Float8Layout, + PerRow, + _check_hardware_support, + _fp8_mm_compat, + _linear_extra_repr, + to_affine_quantized_floatx, + to_linear_activation_quantized, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) +from torchao.utils import ( + is_MI300, + is_sm_at_least_89, +) + + +@dataclass +class Float8DynamicSemiSparseActivationFloat8WeightConfig(AOBaseConfig): + """ + Configuration for applying float8 dynamic symmetric quantization + 2:4 sparsity to the activations and float8 dynamic quantization to the weights + + Args: + activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. + weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. + granularity: + The granularity for quantization. Can be either a single granularity (applied to both + activations and weights) or a tuple of two granularities (one for activations, one for weights). + If None, defaults to PerRowfor both. Currently both quantizations need to be the same type. And + only PerRow is currently supported. + mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. 
+ + """ + + activation_dtype: torch.dtype = e4m3_dtype + weight_dtype: torch.dtype = e4m3_dtype + granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None + mm_config: Optional[Float8MMConfig] = None + set_inductor_config: bool = True + + def __post_init__(self): + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + + activation_granularity, weight_granularity = _normalize_granularity( + self.granularity + ) + self.granularity = [activation_granularity, weight_granularity] + + +def _float8_dynamic_sparse_activation_float8_weight_quantize_tensor(weight, config): + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + + # Ensure works on device + _check_hardware_support(granularity) + activation_granularity, weight_granularity = granularity + + if not _fp8_mm_compat(weight): + return weight + + if isinstance(weight_granularity, PerRow): + assert weight.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input weight" + ) + + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1] + list(block_size)) + + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + # use sparsify function here instead of default fp8 quant func + input_quant_func = _input_activation_quant_func_fp8_sparse + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + return quantized_weight + + +@register_quantize_module_handler(Float8DynamicSemiSparseActivationFloat8WeightConfig) +def _float8_dynamic_activation_sparse_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicSemiSparseActivationFloat8WeightConfig +): + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying float8 dynamic activation quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + quantized_weight = _float8_dynamic_sparse_activation_float8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + +def _input_activation_quant_func_fp8_sparse( + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: Optional[torch.Tensor] = None, + zero_point: Optional[torch.Tensor] = None, +): + """This function is used to quantize + sparsify the input activation tensor for an aqt_float variant. If scale + is not provided it will be dynamically calculate the scales otherwise it will use the provided scale. 
+ """ + assert zero_point is None, ( + "Zero point is not supported for dynamic FP8 quantization" + ) + + assert isinstance(activation_granularity, PerRow), ( + "Only PerRow quantization is currently supported" + ) + assert x.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input activation" + ) + + block_size = get_block_size(x.shape, activation_granularity) + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + # we change the sparsification routine via Layout + _layout=CutlassSemiSparseLayout(), + ) + return activation diff --git a/torchao/sparsity/activation/squared_relu_sparse.py b/torchao/sparsity/activation/squared_relu_sparse.py deleted file mode 100644 index 7424a75514..0000000000 --- a/torchao/sparsity/activation/squared_relu_sparse.py +++ /dev/null @@ -1,385 +0,0 @@ -import types -from dataclasses import dataclass -from typing import List, Optional, Union - -import torch - -import torchao -from torchao.core.config import AOBaseConfig -from torchao.dtypes import ( - CutlassSemiSparseLayout, - Float8Layout, - to_affine_quantized_floatx, -) -from torchao.float8.config import e4m3_dtype -from torchao.float8.inference import ( - Float8MMConfig, - FP8Granularity, - _check_hardware_support, - _normalize_granularity, -) -from torchao.quantization.observer import get_block_size -from torchao.quantization.quant_api import ( - PerRow, - _float8_cutlass_quant, - _linear_extra_repr, - to_linear_activation_quantized, -) -from torchao.quantization.transform_module import ( - register_quantize_module_handler, -) -from torchao.utils import ( - is_MI300, - is_sm_at_least_89, -) - -import torch.nn.functional as F - -from torchao.utils import TorchAOBaseTensor - - -@dataclass -class ActivationSparseLinearConfig(AOBaseConfig): - """ - Adds in acceleration for activation sparsity to linear layers for decode. - - Args: - `activation_dtype`: data type for quantized activation tensor. - `weight_dtype`: data type for quantized weight tensor. 
- """ - - activation_dtype: torch.dtype = torch.float8_e4m3fn - weight_dtype: torch.dtype = torch.float8_e4m3fn - - mm_config = Float8MMConfig(use_fast_accum=True) - - -@register_quantize_module_handler(ActivationSparseLinearConfig) -def _( - module: torch.nn.Module, - config: ActivationSparseLinearConfig, -): - new_weight = ActivationSparseTensor.from_dense(module.weight.data) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - - -class ActivationSparseTensor(TorchAOBaseTensor): - data: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - - __slots__ = ["data", "scale"] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - data: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - requires_grad: bool = False, - ): - assert data is not None - kwargs = { - "device": data.device, - "dtype": data.dtype, - "layout": data.layout, - "requires_grad": requires_grad, - } - tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - tensor.data = data - tensor.scale = scale - return tensor - - def __repr__(self) -> str: # type: ignore[override] - assert hasattr(self, "shape") - return f"{self.__class__.__name__}(shape={self.shape})" - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self.requires_grad) - return inner_tensors, tensor_meta - - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta, - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, requires_grad = tensor_meta - return cls( - shape=shape, - data=inner_tensors.get("data", None), - scale=inner_tensors.get("scale", None), - requires_grad=requires_grad, - ) - - @classmethod - def from_dense(cls, weight, use_fp8=True): - if use_fp8: - W_aqt = _float8_cutlass_quant(weight, torch.float8_e4m3fn) - W = W_aqt.tensor_impl.float8_data - W_scale = W_aqt.tensor_impl.scale - return cls(weight.shape, data=W, scale=W_scale, requires_grad=False) - else: - return cls( - weight.shape, - data=weight.data.t().contiguous().t(), - scale=None, - requires_grad=False, - ) - - def apply_fn_to_shard(self, func): - return ActivationSparseTensor( - shape=self.shape, - data=func(self.data), - scale=func(self.scale), - requires_grad=self.requires_grad, - ) - - -# Subclass op dispatch registration -implements = ActivationSparseTensor.implements -aten = torch.ops.aten - - -@implements( - [ - aten.detach.default, - aten.slice.Tensor, - ] -) -def _(func, types, args, kwargs): - new_data = func(args[0].data, *args[1:], **kwargs) - if args[0].scale is None: - new_scale = None - else: - new_scale = func(args[0].scale, *args[1:], **kwargs) - return ActivationSparseTensor( - new_data.shape, - data=new_data, - scale=new_scale, - requires_grad=False, - ) - - -@implements([aten.copy_.default]) -def _(func, types, args, kwargs): - self = args[0] - src = args[1] - if not isinstance(src, ActivationSparseTensor): - src_subclass = ActivationSparseTensor.from_dense(src) - self.data.copy_(src_subclass.data) - self.scale.copy_(src_subclass.scale) - return - - - - -@implements(torch.nn.functional.linear) -def sparse_activation_linear(func, types, args, kwargs): - x_orig, w, bias = args - assert bias is None - x = x_orig.view(-1, x_orig.size(-1)) - m, n = x.shape - - # # # if x input is the right shape, we use sparse matmul - # x_padded = _pad_dense_input(x) - # if (x.size(0) % 64) 
== 0: - # if (x.size(0) == 64) or (x.size(0) == 128) or (x.size(0) ==256) or (x.size(0)==512): - if False: - X_scale = torch.empty( - [x.shape[0], 1], dtype=torch.float32, device=x_orig.device - ) - Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( - x, - "cutlass", - "identity", - "largest", - dtype=torch.float8_e4m3fn, - scale=X_scale, - ) - - out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( - Xq_sparse, - X_meta, - w.data.t(), - a_scale=X_scale, - b_scale=w.scale.t(), - ) - # print(out_sparse.shape) - out_sparse = out_sparse.reshape(*x_orig.shape[:-1], w.shape[0]) - return out_sparse - else: - w_dequantized = (w.data.to(torch.float32) * w.scale).to(torch.bfloat16) - return torch.nn.functional.linear(x_orig, w_dequantized, bias) - - -from torchao.quantization.quant_api import ( - Float8Layout, - _check_hardware_support, - _fp8_mm_compat, - to_affine_quantized_floatx, -) - - -@dataclass -class Float8DynamicSemiSparseActivationFloat8WeightConfig(AOBaseConfig): - """ - Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers. - - Args: - activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. - weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. - granularity: - The granularity for quantization. Can be either a single granularity (applied to both - activations and weights) or a tuple of two granularities (one for activations, one for weights). - If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And - only PerTensor and PerRow are supported. - mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. - set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. 
- - """ - - activation_dtype: torch.dtype = e4m3_dtype - weight_dtype: torch.dtype = e4m3_dtype - granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None - mm_config: Optional[Float8MMConfig] = None - set_inductor_config: bool = True - - def __post_init__(self): - if self.mm_config is None: - self.mm_config = Float8MMConfig(use_fast_accum=True) - - activation_granularity, weight_granularity = _normalize_granularity( - self.granularity - ) - self.granularity = [activation_granularity, weight_granularity] - - -def _float8_dynamic_sparse_activation_float8_weight_quantize_tensor(weight, config): - activation_dtype = config.activation_dtype - weight_dtype = config.weight_dtype - granularity = config.granularity - mm_config = config.mm_config - - # Ensure works on device - _check_hardware_support(granularity) - activation_granularity, weight_granularity = granularity - - if not _fp8_mm_compat(weight): - # TODO(future PR): this should really throw an exception instead of silently - # not doing what the user asked - return weight - if isinstance(weight_granularity, PerRow): - assert weight.dtype == torch.bfloat16, ( - "PerRow quantization only works for bfloat16 precision input weight" - ) - block_size = get_block_size(weight.shape[-2:], weight_granularity) - if weight.dim() == 3: - block_size = tuple([1] + list(block_size)) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) - - # input_quant_func = torch.compile(_input_activation_quant_func_fp8_sparse, fullgraph=True) - input_quant_func = _input_activation_quant_func_fp8_sparse - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } - - - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs - ) - return quantized_weight - - -@register_quantize_module_handler(Float8DynamicSemiSparseActivationFloat8WeightConfig) -def _float8_dynamic_activation_sparse_float8_weight_transform( - module: torch.nn.Module, config: Float8DynamicSemiSparseActivationFloat8WeightConfig -): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - - assert hasattr(module, "weight"), ( - "applying float8 dynamic activation quant requires module to have weight attribute" - + f"but {module} does not have one" - ) - quantized_weight = _float8_dynamic_sparse_activation_float8_weight_quantize_tensor( - module.weight, config - ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - -from collections import Counter -from pprint import pprint - -SEEN = Counter() -def _input_activation_quant_func_fp8_sparse( - x: torch.Tensor, - activation_granularity, - activation_dtype: torch.dtype, - scale: Optional[torch.Tensor] = None, - zero_point: Optional[torch.Tensor] = None, -): - """This function is used to quantize the input activation tensor for an aqt_float variant. If scale - is not provided it will be dynamically calculate the scales otherwise it will use the provided scale. 
- """ - # print(x.shape) - # x_2d = x.view(-1, x.size(-1)) - - assert zero_point is None, ( - "Zero point is not supported for dynamic FP8 quantization" - ) - if isinstance(activation_granularity, PerRow): - assert x.dtype == torch.bfloat16, ( - "PerRow quantization only works for bfloat16 precision input activation" - ) - - # x_2d = _pad_dense_input(x_2d) - # if x.shape not in SEEN: - # SEEN[x.shape] += 1 - # pprint(SEEN) - # else: - # SEEN[x.shape] += 1 - - - # if ( - # (x.size(0) == 64) or - # (x.size(0) == 128) or - # (x.size(0) == 192) or - # (x.size(0) == 256) or - # (x.size(0) == 320) or - # (x.size(0) == 384) or - # (x.size(0) == 448) or - # (x.size(0) == 512) - # ): - # print(x.shape) - # if x.shape[0] % 64 == 0: - # else: - # layout=Float8Layout(mm_config=None) - layout=CutlassSemiSparseLayout() - - block_size = get_block_size(x.shape, activation_granularity) - activation = to_affine_quantized_floatx( - input_float=x, - block_size=block_size, - target_dtype=activation_dtype, - scale_dtype=torch.float32, - _layout=layout, - ) - return activation diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 0fcf04fdac..0dc4ff7e87 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -11,7 +11,6 @@ from torch.sparse import to_sparse_semi_structured from torchao.core.config import AOBaseConfig -from torchao.float8.inference import Float8MMConfig from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( WeightNormSparsifier, ) @@ -24,25 +23,10 @@ _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) -from torchao.sparsity.blocksparse import BlockSparseTensor -from dataclasses import dataclass - -import torch -from torch import nn - -from torchao.core.config import AOBaseConfig -from torchao.ops import ( - rowwise_scaled_linear_sparse_cutlass_f8f8, -) -from torchao.quantization.quant_api import ( - _float8_cutlass_quant, +from torchao.sparsity.activation.float8dynamic_24 import ( + Float8DynamicSemiSparseActivationFloat8WeightConfig, # noqa: F401 ) -from torchao.quantization.transform_module import ( - register_quantize_module_handler, -) - -from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv -from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.sparsity.blocksparse import BlockSparseTensor # Sparsity helper functions @@ -153,10 +137,3 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: _is_linear if filter_fn is None else filter_fn, extra_args=(config,), ) - - - -from torchao.sparsity.activation.squared_relu_sparse import ( - ActivationSparseLinearConfig, - Float8DynamicSemiSparseActivationFloat8WeightConfig, -) From b41ee442bf99ccffc7dd855b2695722c6726c7a3 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 23:27:02 -0700 Subject: [PATCH 19/25] reset --- torchao/csrc/cuda/activation24/sparse_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index 7ed7702c7e..e37c466119 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -113,7 +113,7 @@ struct SparseRowwiseKernel { cutlass::layout::ColumnMajor, 16, ElementAccumulator, - cute::Shape, + cute::Shape, cute::Shape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, From e2593c4782d60b9a8b5852ffa4c79ee0e4d1011e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 
23:29:12 -0700 Subject: [PATCH 20/25] remove unneeded changes --- torchao/ops.py | 32 +++++++++++++++--------- torchao/quantization/quant_api.py | 4 --- torchao/quantization/quant_primitives.py | 2 -- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/torchao/ops.py b/torchao/ops.py index a44779f2ec..b91bb8ae18 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -843,18 +843,26 @@ def sparse24_sm90_sparsify( ) -# @register_custom_op("torchao::sparse24_fp8_sm90_cutlass_gemm") -# def _( -# a: Tensor, -# meta: Tensor, -# b: Tensor, -# a_scale: Optional[Tensor], -# b_scale: Optional[Tensor], -# swizzle_size: int = 8, -# swizzle_axis: str = 'n', -# sm_count: int = 128, -# ) -> Tensor: -# return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device) +def sparse24_fp8_sm90_cutlass_gemm( + a: Tensor, + meta: Tensor, + b: Tensor, + a_scale: Optional[Tensor], + b_scale: Optional[Tensor], + swizzle_size: int, + swizzle_axis: str, + sm_count: int, +) -> Tensor: + return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + a, + meta, + b, + a_scale=a_scale, + b_scale=b_scale, + swizzle_size=swizzle_size, + swizzle_axis=swizzle_axis, + sm_count=sm_count, + ) def swizzle_mm( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 39223e6e50..f2aca97782 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -75,7 +75,6 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, - TorchAOBaseTensor, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -538,9 +537,6 @@ def _quantization_type(weight: torch.Tensor): if type(weight) is torch.Tensor: return "not quantized" - if isinstance(weight, TorchAOBaseTensor): - return f"{weight.__class__.__name__}" - return "not recognized" diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 446f512c65..cee8df21a2 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -674,8 +674,6 @@ def _dequantize_affine_no_dtype_check( 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain 3. 
reshape the quantized result to origianl shape and change dtype to the output_dtype """ - if len(block_size) != input.dim(): - breakpoint() assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" ) From f2dab64b996b8b4ff2cd27594b9acf08c562a379 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 23:41:16 -0700 Subject: [PATCH 21/25] more cleanup --- torchao/csrc/cuda/activation24/sparse_gemm.cu | 4 +-- torchao/dtypes/affine_quantized_tensor.py | 31 +++++++------------ .../floatx/cutlass_semi_sparse_layout.py | 7 ++--- .../sparsity/activation/float8dynamic_24.py | 3 -- 4 files changed, 16 insertions(+), 29 deletions(-) diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index e37c466119..452daadb5b 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -113,7 +113,7 @@ struct SparseRowwiseKernel { cutlass::layout::ColumnMajor, 16, ElementAccumulator, - cute::Shape, + cute::Shape, cute::Shape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, @@ -193,7 +193,7 @@ struct SparseRowwiseKernel { cutlass::layout::ColumnMajor, 16, float, - cute::Shape, + cute::Shape, cute::Shape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 00e808b01d..050773b4fa 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -458,33 +458,24 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): - from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( - CutlassSemiSparseLayout, - ) - """Convert a high precision tensor to a float8 quantized tensor.""" if target_dtype in FP8_TYPES: original_shape = input_float.shape input_float = _layout.pre_process(input_float) - - # handle CUTLASS specially - if isinstance(_layout, CutlassSemiSparseLayout): - scale = choose_qparams_affine_float8( - input_float, float8_dtype=target_dtype, block_size=block_size - ) - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(input_float, scale, None, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - dtype=input_float.dtype, - ) - scale = choose_qparams_affine_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) - data = quantize_affine_float8(input_float, scale, target_dtype) + + # need to import here to avoid circular import + from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ( + CutlassSemiSparseLayout, + ) + + if isinstance(_layout, CutlassSemiSparseLayout): + # handle sparse activation specially, since the sparsification kernel also does the quantization + data = input_float + else: + data = quantize_affine_float8(input_float, scale, target_dtype) data, scale, zero_point = _layout.post_process( data, scale, None, block_size ) diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 2449433078..9727701bd0 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -203,8 +203,7 @@ def from_plain( dense_padded = _pad_dense_input(dense) scale_padded = _pad_scale(scale) - # X_scale = torch.empty((dense.shape[0], 1), device=dense.device, 
dtype=torch.float32) - Xq_sparse, X_meta = torch.ops.torchao.sparse24_sm90_sparsify( + sparse, meta = torch.ops.torchao.sparse24_sm90_sparsify( dense_padded, "cutlass", "identity", @@ -215,8 +214,8 @@ def from_plain( res = cls( dense.shape, - Xq_sparse, - X_meta, + sparse, + meta, scale_padded, _layout, ) diff --git a/torchao/sparsity/activation/float8dynamic_24.py b/torchao/sparsity/activation/float8dynamic_24.py index e6e1f69f2d..1dbebccad3 100644 --- a/torchao/sparsity/activation/float8dynamic_24.py +++ b/torchao/sparsity/activation/float8dynamic_24.py @@ -9,7 +9,6 @@ from torchao.dtypes import ( CutlassSemiSparseLayout, Float8Layout, - to_affine_quantized_floatx, ) from torchao.float8.config import e4m3_dtype from torchao.float8.inference import ( @@ -20,9 +19,7 @@ ) from torchao.quantization.observer import get_block_size from torchao.quantization.quant_api import ( - Float8Layout, PerRow, - _check_hardware_support, _fp8_mm_compat, _linear_extra_repr, to_affine_quantized_floatx, From 6102a49ec03429723c3220dd5399a8485e678894 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 23:56:06 -0700 Subject: [PATCH 22/25] ruff passing --- benchmarks/benchmark_e2e_fp8_sparse_linear.py | 37 +------------------ benchmarks/benchmark_splitk_sparse_gemv.py | 11 +----- .../floatx/cutlass_semi_sparse_layout.py | 10 ----- 3 files changed, 4 insertions(+), 54 deletions(-) diff --git a/benchmarks/benchmark_e2e_fp8_sparse_linear.py b/benchmarks/benchmark_e2e_fp8_sparse_linear.py index 3c9ac7a70c..a72e5ea449 100644 --- a/benchmarks/benchmark_e2e_fp8_sparse_linear.py +++ b/benchmarks/benchmark_e2e_fp8_sparse_linear.py @@ -20,8 +20,6 @@ Float8DynamicSemiSparseActivationFloat8WeightConfig, ) -PROFILE = False - def benchmark_microseconds(f, *args): return do_bench(lambda: f(*args), return_mode="median") * 1e3 @@ -73,33 +71,12 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor) - # fp8 sparse - # ffn_clone = ( - # nn.Sequential( - # nn.Linear(hidden_size, intermediate_size, bias=False), - # SquaredReLU(), - # nn.Linear(intermediate_size, hidden_size, bias=False), - # ) - # .to(torch.bfloat16) - # .cuda() - # ) - # quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig()) - # ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) - # fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) - - if PROFILE: - print("PROFILING FP8") - from torchao.prototype.sparsity.activation.utils import profiler_runner - - inputs = (ffn_clone, input_tensor) - profiler_runner(None, benchmark_microseconds, *inputs) - # activation fp8 sparse ffn_clone = ( nn.Sequential( nn.Linear(hidden_size, intermediate_size, bias=False), # no Squared RELU since it will be fused into the second linear - # SquaredReLU(), + SquaredReLU(), nn.Linear(intermediate_size, hidden_size, bias=False), ) .to(torch.bfloat16) @@ -112,28 +89,19 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): ), ) quantize_( - ffn_clone[1], + ffn_clone[2], Float8DynamicSemiSparseActivationFloat8WeightConfig( granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) ), - # filter_fn=lambda mod, fqn: "1" in fqn, ) ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True) fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor) - if PROFILE: - print("PROFILING 24") - from 
torchao.prototype.sparsity.activation.utils import profiler_runner - - inputs = (ffn_clone, input_tensor) - profiler_runner(None, benchmark_microseconds, *inputs) - return { "num_tokens": num_tokens, "bf16_latency (us)": fp16_time, "bf16_c_latency (us)": fp16_c_time, "fp8_c_time (us)": fp8_c_time, - # "fp8_c_sparse_time (us)": fp8_c_sparse_time, "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time, "speedup": fp8_c_time / fp8_c_activation_sparse_time, } @@ -142,7 +110,6 @@ def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384): if __name__ == "__main__": with torch.no_grad(): results = [] - # for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]): for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]): results.append(benchmark(num_tokens)) torch.compiler.reset() diff --git a/benchmarks/benchmark_splitk_sparse_gemv.py b/benchmarks/benchmark_splitk_sparse_gemv.py index de895ca962..9623ce5d59 100644 --- a/benchmarks/benchmark_splitk_sparse_gemv.py +++ b/benchmarks/benchmark_splitk_sparse_gemv.py @@ -5,25 +5,18 @@ from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.sparsity.utils import create_binary_tensor -dtype = torch.float8_e4m3fn +dtype = torch.bfloat16 for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]: a = create_binary_tensor((1, 4096), sparsity_level).cuda().to(dtype) b = torch.randn(16384, 4096).cuda().to(dtype).T.contiguous().T - sparse_time = ( - do_bench(lambda: splitk_sparse_gemv(a, b, out_dtype=torch.bfloat16)) * 1e6 - ) + sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6 dense_time = ( do_bench(lambda: F.linear(a.to(torch.float16), b.to(torch.float16))) * 1e6 ) - # b = torch.randn(4096, 16384).cuda().to(dtype).T.contiguous().T - # dense_time = do_bench(lambda: torch._scaled_mm(a.squeeze(0), b, - # scale_a=torch.Tensor([1]).cuda(), - # scale_b=torch.Tensor([1]).cuda(), - # out_dtype=torch.bfloat16)) * 1e6 speedup = dense_time / sparse_time print( f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}" diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 9727701bd0..e298aec77d 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -78,12 +78,6 @@ def _same_metadata( class CutlassSemiSparseLayout(Layout): """Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel.""" - # def pre_process(self, dense: torch.Tensor) -> torch.Tensor: - # # prune to 2:4 if not already - # from torchao.sparsity.utils import mask_creator - - # return dense * mask_creator(dense).bool() - @register_layout(CutlassSemiSparseLayout) class CutlassSemiSparseTensorImpl(AQTTensorImpl): @@ -161,8 +155,6 @@ def get_plain(self): # No support in CUTLASS to convert back to dense from sparse # semi-structured format, so multiplying with identity matrix, # and using identity scale factors, for the conversion. 
- # breakpoint() - # raise NotImplementedError("get_plain not supported for CutlassSemiSparseTensorImpl") cols = self.shape[-1] input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) input_scale = torch.ones( @@ -195,8 +187,6 @@ def from_plain( _layout: Layout, ): assert zero_point is None or torch.all(zero_point == 0) - # print(dense.shape) - # dense_2d = dense.view(-1, dense.shape[-1]) assert dense.ndim == 2 assert dense.is_contiguous() From ec8d5a99c082c43934b192f1d57949be7b7a9750 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 29 May 2025 23:58:31 -0700 Subject: [PATCH 23/25] rename file --- test/sparsity/test_activation24.py => test_activation.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/sparsity/test_activation24.py => test_activation.py (100%) diff --git a/test/sparsity/test_activation24.py b/test_activation.py similarity index 100% rename from test/sparsity/test_activation24.py rename to test_activation.py From d41eda23b66df2453736e04fbe48103c1da6b094 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 30 May 2025 01:02:26 -0700 Subject: [PATCH 24/25] move test to test dir --- .../sparsity/test_activation.py | 60 ++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) rename test_activation.py => test/sparsity/test_activation.py (80%) diff --git a/test_activation.py b/test/sparsity/test_activation.py similarity index 80% rename from test_activation.py rename to test/sparsity/test_activation.py index dbab3bf8bb..e4cba46e7e 100644 --- a/test_activation.py +++ b/test/sparsity/test_activation.py @@ -1,7 +1,9 @@ import torch import torch.nn.functional as F +from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ActivationFunction from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 +from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8MMConfig, @@ -13,8 +15,6 @@ Float8DynamicSemiSparseActivationFloat8WeightConfig, ) -torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True - import copy import unittest @@ -23,6 +23,7 @@ from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor from torchao.utils import is_sm_at_least_90 +from torchao.sparsity import ActivationFunction @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") @@ -155,6 +156,61 @@ def test_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False): torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) +@parameterized.expand( + [ + # (1, 8192, 1024, True), + # (64, 8192, 1024, True), + # (1024, 8192, 1024, True), + # (1, 8192, 1024, False), + (64, 8192, 1024, False), + # (1024, 8192, 1024, False), + ] +) +@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +def test_srelu_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False): + with torch.no_grad(): + torch.manual_seed(0) + input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + # we have to wrap in a sequential block for quantize_ to work properly + reference_linear = torch.nn.Sequential( + SquaredReLU(), + torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16) + ) + reference_linear_copy = copy.deepcopy(reference_linear[1]) + print(reference_linear_copy) + + quantize_( + reference_linear, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), 
mm_config=Float8MMConfig(use_fast_accum=True) + ), + ) + + if do_compile: + reference_linear.forward = torch.compile( + reference_linear.forward, + fullgraph=True, + ) + + quantize_( + reference_linear_copy, + Float8DynamicSemiSparseActivationFloat8WeightConfig( + activation_fn=ActivationFunction.SQUARED_RELU, + granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) + ), + ) + print(reference_linear_copy) + + if do_compile: + reference_linear_copy.forward = torch.compile( + reference_linear_copy.forward, fullgraph=True + ) + + reference_output = reference_linear(input_tensor) + custom_output = reference_linear_copy(input_tensor) + + torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) + @unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run") def test_splitk_sparse_gemv(): From 61aedfda446b013e1f63933c7d6524ed3fbaffc5 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 30 May 2025 07:51:04 -0700 Subject: [PATCH 25/25] update --- test/sparsity/test_activation.py | 70 +++----------------------------- 1 file changed, 5 insertions(+), 65 deletions(-) diff --git a/test/sparsity/test_activation.py b/test/sparsity/test_activation.py index e4cba46e7e..b592e2a888 100644 --- a/test/sparsity/test_activation.py +++ b/test/sparsity/test_activation.py @@ -1,9 +1,12 @@ +import copy +import unittest + import torch import torch.nn.functional as F +from parameterized import parameterized -from torchao.dtypes.floatx.cutlass_semi_sparse_layout import ActivationFunction +from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 -from torchao.prototype.sparsity.activation.utils import SquaredReLU from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8MMConfig, @@ -14,16 +17,8 @@ from torchao.sparsity.sparse_api import ( Float8DynamicSemiSparseActivationFloat8WeightConfig, ) - -import copy -import unittest - -from parameterized import parameterized - -from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor from torchao.utils import is_sm_at_least_90 -from torchao.sparsity import ActivationFunction @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") @@ -156,61 +151,6 @@ def test_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False): torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) -@parameterized.expand( - [ - # (1, 8192, 1024, True), - # (64, 8192, 1024, True), - # (1024, 8192, 1024, True), - # (1, 8192, 1024, False), - (64, 8192, 1024, False), - # (1024, 8192, 1024, False), - ] -) -@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") -def test_srelu_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False): - with torch.no_grad(): - torch.manual_seed(0) - input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() - # we have to wrap in a sequential block for quantize_ to work properly - reference_linear = torch.nn.Sequential( - SquaredReLU(), - torch.nn.Linear(K, N, bias=False).cuda().to(torch.bfloat16) - ) - reference_linear_copy = copy.deepcopy(reference_linear[1]) - print(reference_linear_copy) - - quantize_( - reference_linear, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) - ), - ) - - if do_compile: - reference_linear.forward = torch.compile( - reference_linear.forward, - 
fullgraph=True, - ) - - quantize_( - reference_linear_copy, - Float8DynamicSemiSparseActivationFloat8WeightConfig( - activation_fn=ActivationFunction.SQUARED_RELU, - granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True) - ), - ) - print(reference_linear_copy) - - if do_compile: - reference_linear_copy.forward = torch.compile( - reference_linear_copy.forward, fullgraph=True - ) - - reference_output = reference_linear(input_tensor) - custom_output = reference_linear_copy(input_tensor) - - torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) - @unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run") def test_splitk_sparse_gemv():
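
For readers following the series end to end, a minimal usage sketch of the config added in these patches follows. It is assembled from the imports and calls that appear in test/sparsity/test_activation.py and benchmarks/benchmark_e2e_fp8_sparse_linear.py above; the layer sizes, batch size, and use_fast_accum=True are illustrative choices taken from those files rather than requirements, and the snippet assumes a CUDA device new enough for the fp8 2:4 CUTLASS kernels (SM90-class hardware in the tests).

import torch

from torchao.quantization import Float8MMConfig, PerRow, quantize_
from torchao.sparsity.sparse_api import (
    Float8DynamicSemiSparseActivationFloat8WeightConfig,
)

# PerRow float8 quantization asserts bfloat16 inputs, so keep the module in bf16 on CUDA.
linear = torch.nn.Linear(8192, 1024, bias=False).cuda().to(torch.bfloat16)

quantize_(
    linear,
    Float8DynamicSemiSparseActivationFloat8WeightConfig(
        granularity=PerRow(),
        mm_config=Float8MMConfig(use_fast_accum=True),
    ),
)

# Weights are quantized to fp8 at quantize_ time; activations are quantized and pruned
# to 2:4 dynamically on each forward (CutlassSemiSparseLayout) before the sparse GEMM.
x = torch.randn(64, 8192, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = linear(x)  # expected shape: (64, 1024)

Note that the weight-side work happens once during quantize_, while the activation path (fp8 quantization plus 2:4 sparsification via the sparse24_sm90_sparsify op) runs on every call, which is the portion the end-to-end benchmark above is measuring.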