Add Float8BlockwiseLinear with Triton kernels for quantization and GEMMs #2592

Open · wants to merge 1 commit into main

2 changes: 1 addition & 1 deletion benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -13,7 +13,7 @@
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8.kernels import (
blockwise_fp8_gemm,
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
169 changes: 169 additions & 0 deletions benchmarks/float8/bench_fp8_blockwise_quant_kernels.py
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_microseconds

from torchao.prototype.blockwise_fp8.kernels import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
torch_blockwise_scale_act_quant,
torch_blockwise_scale_weight_quant,
triton_quantize_fp8_block,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
A_shape: tuple[int, int]
block_m: int
block_k: int


@dataclass(frozen=True)
class ExperimentResult:
torch_us: float
fbgemm_us: float
deepgemm_us: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
A_shapes = [
(1024, 1024),
(2048, 2048),
(4096, 4096),
(8192, 8192),
(16384, 16384),
(32768, 32768),
]
block_m_opts = [1, 128]
block_k_opts = [
128,
]
configs = []
for A_shape, block_m, block_k in itertools.product(
A_shapes,
block_m_opts,
block_k_opts,
):
configs.append(
ExperimentConfig(
A_shape=A_shape,
block_m=block_m,
block_k=block_k,
)
)
return configs


def run_experiment(
config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
A = torch.randn(
*config.A_shape,
dtype=torch.bfloat16,
device=device,
)

# Torch and DeepGEMM implementations are specific to activation quantization (1 x block_size)
# and weight quantization (block_size x block_size)
if config.block_m == 1:
torch_func = torch.compile(torch_blockwise_scale_act_quant)
deepgemm_func = fp8_blockwise_act_quant
else:
torch_func = torch.compile(torch_blockwise_scale_weight_quant)
deepgemm_func = fp8_blockwise_weight_quant

# Validate output shapes and strides
torch_out, torch_scale = torch_func(A, tile_size=config.block_k)
deepgemm_out, deepgemm_scale = deepgemm_func(A, block_size=config.block_k)
fbgemm_out, fbgemm_scale = triton_quantize_fp8_block(
A, block_m=config.block_m, block_k=config.block_k, k_major=True
)
assert torch_out.shape == deepgemm_out.shape == fbgemm_out.shape
assert torch_out.stride() == deepgemm_out.stride() == fbgemm_out.stride()
assert torch_scale.shape == deepgemm_scale.shape == fbgemm_scale.shape
assert torch_scale.stride() == deepgemm_scale.stride() == fbgemm_scale.stride()

# Do benchmarking
torch_us = benchmark_microseconds(torch_func, A, tile_size=config.block_k)
deepgemm_us = benchmark_microseconds(
deepgemm_func, A, block_size=config.block_k
)
fbgemm_us = benchmark_microseconds(
triton_quantize_fp8_block,
A,
block_m=config.block_m,
block_k=config.block_k,
k_major=True,
)

return ExperimentResult(
torch_us=round(torch_us, 3),
fbgemm_us=round(fbgemm_us, 3),
deepgemm_us=round(deepgemm_us, 3),
)


def print_results(experiments: List[Experiment]):
headers = [
"A_shape",
"block_shape",
"torch_us",
"fbgemm_us",
"deepgemm_us",
]
rows = []
for experiment in experiments:
A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
block_shape = f"({experiment.config.block_m},{experiment.config.block_k})"
rows.append(
[
A_shape,
block_shape,
experiment.result.torch_us,
experiment.result.fbgemm_us,
experiment.result.deepgemm_us,
]
)
print(tabulate(rows, headers=headers))


def main(args: argparse.Namespace):
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config, args)
results.append(Experiment(config=config, result=result))

# Use Tabulate to print results
print_results(results)


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--compile", action="store_true")
args = arg_parser.parse_args()
main(args)
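
For readers skimming the benchmark, the two tiling schemes it compares are (1 x block_size) tiles for activations and (block_size x block_size) tiles for weights. Below is a minimal pure-PyTorch sketch of what that blockwise scaling computes; it is illustrative only and not the Triton/fbgemm/DeepGEMM kernels benchmarked above: the helper names are invented, dimensions are assumed divisible by the block size, and a direct (non-reciprocal) scale convention with the float8_e4m3fn max of 448.0 is assumed.

```python
import torch

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max


def quantize_act_1xB(x: torch.Tensor, block: int = 128):
    # Activation-style scaling: one scale per (1 x block) tile along K.
    M, K = x.shape  # assumes K % block == 0
    xb = x.reshape(M, K // block, block)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / amax  # direct scale convention, for illustration only
    q = (xb * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(M, K), scale.reshape(M, K // block)


def quantize_weight_BxB(w: torch.Tensor, block: int = 128):
    # Weight-style scaling: one scale per (block x block) 2D tile.
    N, K = w.shape  # assumes N % block == 0 and K % block == 0
    wb = w.reshape(N // block, block, K // block, block).permute(0, 2, 1, 3)
    amax = wb.abs().amax(dim=(-2, -1), keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / amax
    q = (wb * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.permute(0, 2, 1, 3).reshape(N, K), scale.reshape(N // block, K // block)
```

For a (4096, 4096) input with block=128, the activation scheme produces a (4096, 32) scale tensor while the weight scheme produces a (32, 32) one; that asymmetry is what the block_m option above toggles between.
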
10 changes: 10 additions & 0 deletions benchmarks/float8/utils.py
@@ -11,6 +11,7 @@

import torch.utils.benchmark as benchmark
from torch.profiler import ProfilerActivity, profile
from triton.testing import do_bench


def profiler_output_to_filtered_time_by_kernel_name(
@@ -428,3 +429,12 @@ def do_benchmarks(
tops_sec = float(tops) / time_sec
pct_top_peak = tops_sec / peak_tops
return time_sec, tops_sec, pct_top_peak


def benchmark_microseconds(f, *args, warmup=25, rep=100, **kwargs):
return (
do_bench(
lambda: f(*args, **kwargs), warmup=warmup, rep=rep, return_mode="median"
)
* 1e3
)
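
A usage note on this helper: triton's do_bench returns a median runtime in milliseconds, so the wrapper multiplies by 1e3 to report microseconds. A minimal sketch of calling it follows; the matmul workload is arbitrary and purely illustrative.

```python
import torch
from utils import benchmark_microseconds  # same relative import the benchmark script above uses

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Positional and keyword arguments are forwarded to the callable being timed.
matmul_us = benchmark_microseconds(torch.matmul, a, b)
print(f"bf16 matmul: {matmul_us:.1f} us")
```
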
180 changes: 180 additions & 0 deletions test/prototype/blockwise_fp8/test_blockwise_kernels.py
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from packaging import version
from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8.kernels import (
blockwise_fp8_gemm_1x128_1x128,
blockwise_fp8_gemm_1x128_128x128,
fp8_blockwise_act_quant,
fp8_blockwise_weight_dequant,
fp8_blockwise_weight_quant,
torch_blockwise_scale_act_quant,
torch_blockwise_scale_weight_quant,
triton_quantize_fp8_block,
)
from torchao.testing.utils import skip_if_rocm

BLOCKWISE_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_quant_dequant(_, N, K, dtype):
x = torch.randn(N, K).cuda()
qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
sqnr = compute_error(x, x_reconstructed)
assert sqnr >= 25.0, f"SQNR {sqnr:.2f} must be >= 25.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
version.parse(triton.__version__) < version.parse("3.3.0"),
reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
A = torch.randn(M, K).cuda()
B = torch.randn(N, K).cuda()
C = A @ B.T
A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
C_q = blockwise_fp8_gemm_1x128_128x128(A_q, A_s, B_q, B_s)
assert not C_q.isnan().any(), "C_q must not contain NaNs"
sqnr = compute_error(C, C_q)
assert sqnr >= 22.0, f"SQNR {sqnr:.2f} must be >= 22.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
version.parse(triton.__version__) < version.parse("3.3.0"),
reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_1x128(M, N, K, dtype):
A = torch.randn(M, K).cuda()
B = torch.randn(N, K).cuda()
C = A @ B.T
A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
B_q, B_s = fp8_blockwise_act_quant(B, dtype=dtype)
C_q = blockwise_fp8_gemm_1x128_1x128(A_q, A_s, B_q, B_s)
assert not C_q.isnan().any(), "C_q must not contain NaNs"
sqnr = compute_error(C, C_q)
assert sqnr >= 22.0, f"SQNR {sqnr:.2f} must be >= 22.0"


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("test_eps", [True, False])
def test_triton_quantize_fp8_act_quant(tile_size: int, test_eps: bool):
device = "cuda"
M, K = 256, 256
x = torch.randn(M, K, device=device)

# set one scaling block to 0s, so if nan guards/EPS are not applied, the
# quantized tensor will have NaNs due to division by 0
if test_eps:
x[0, :tile_size] = 0.0

# Get the quantized tensor and scales using triton implementation
# Use block_m=1 to match the narrow tiles (1 x tile_size) in the reference implementation
triton_fp8, triton_scale = triton_quantize_fp8_block(
x, block_m=1, block_k=tile_size
)
assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

# Get the quantized tensor and scales using reference implementation
ref_fp8, ref_scale = torch_blockwise_scale_act_quant(x, tile_size=tile_size)
assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

# Convert both to float32 for comparison
triton_fp32 = triton_fp8.to(torch.float32)
ref_fp32 = ref_fp8.to(torch.float32)

# Check that the quantized tensors are close
# Note: We use a relatively high tolerance because the implementations might have
# slight differences in how they handle edge cases, rounding, etc.
assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
)

# Check that the scales are close
# Note: The scales might be stored differently (reciprocal vs. direct), so we need to
# be careful about how we compare them

# In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
# In torch_blockwise_scale_act_quant, scales are stored directly
# So we need to take the reciprocal of one of them for comparison

# Reshape triton_scale to match ref_scale shape for comparison
triton_scale_reshaped = triton_scale.reshape(M, -1)

# Compare reciprocal of triton_scale with ref_scale
assert torch.allclose(
1.0 / triton_scale_reshaped, ref_scale, rtol=1e-2, atol=1e-2
), (
f"Scales differ: max diff = {(1.0 / triton_scale_reshaped - ref_scale).abs().max().item()}"
)


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("test_eps", [True, False])
def test_triton_quantize_fp8_weight_quant(tile_size: int, test_eps: bool):
device = "cuda"
# Make sure dimensions are multiples of tile_size for clean comparison
M = tile_size * 2
K = tile_size * 2
x = torch.randn(M, K, device=device)

# set one scaling block to 0s, so if nan guards/EPS are not applied, the
# quantized tensor will have NaNs due to division by 0
if test_eps:
x[:tile_size, :tile_size] = 0.0

# Get the quantized tensor and scales using triton implementation
triton_fp8, triton_scale = triton_quantize_fp8_block(
x, block_m=tile_size, block_k=tile_size
)
assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

# Get the quantized tensor and scales using reference implementation
ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=tile_size)
assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

# Convert both to float32 for comparison
triton_fp32 = triton_fp8.to(torch.float32)
ref_fp32 = ref_fp8.to(torch.float32)

# Check that the quantized tensors are close
assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
)

# Check that the scales are close
# In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
# In torch_blockwise_scale_weight_quant, scales are stored directly

# Compare reciprocal of triton_scale with ref_scale
assert torch.allclose(1.0 / triton_scale, ref_scale, rtol=1e-2, atol=1e-2), (
f"Scales differ: max diff = {(1.0 / triton_scale - ref_scale).abs().max().item()}"
)
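
Both quantization tests normalize for the fact that the two implementations store scales in different conventions: triton_quantize_fp8_block stores reciprocals (1/scale) while the torch reference stores scales directly, which is why they compare 1.0 / triton_scale against ref_scale. A small numeric sketch with hypothetical values (the quantize direction shown is an assumption for illustration only):

```python
import torch

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max
amax = torch.tensor(2.0)  # hypothetical per-block absolute max

direct_scale = FP8_MAX / amax          # "direct" convention (assumed: quantize by multiplying)
reciprocal_scale = 1.0 / direct_scale  # what a reciprocal-convention kernel would store

x = torch.tensor(1.5)
q = (x * direct_scale).clamp(-FP8_MAX, FP8_MAX)

# Dequantization recovers the same value under either stored convention.
assert torch.allclose(q / direct_scale, q * reciprocal_scale)
```
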