WIP: Generalize base and compile tests of float8 #2498

Open · wants to merge 2 commits into base: main
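
The diff below repeatedly replaces hard-coded `"cuda"` device strings and `torch.cuda` checks with the device-agnostic `torch.accelerator` API. A minimal sketch of that pattern for reference (an illustration, not code taken from the PR; it assumes PyTorch 2.7+, matching the version gate this PR introduces):

```python
import torch

# Device-agnostic selection: fall back to CPU when no accelerator is present.
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type  # e.g. "cuda", "xpu"
else:
    device = "cpu"

# Tensors and modules are created against `device` instead of a hard-coded "cuda".
x = torch.randn(16, 32, device=device, dtype=torch.bfloat16)
m = torch.nn.Linear(32, 32, device=device, dtype=torch.bfloat16)

# torch.autocast takes the same device-type string, so that call generalizes too.
with torch.autocast(device, dtype=torch.bfloat16):
    y = m(x)
```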
113 changes: 58 additions & 55 deletions test/float8/test_base.py
@@ -16,12 +16,12 @@

from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_7,
is_sm_at_least_89,
is_sm_at_least_90,
)

if not TORCH_VERSION_AT_LEAST_2_5:
if not TORCH_VERSION_AT_LEAST_2_7:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -237,11 +237,12 @@ def test_axiswise_reshape(self):
(ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
@unittest.skipIf(not torch.accelerator.is_available() or
(not is_sm_at_least_90() and torch.accelerator.current_accelerator().type == "cuda"),
"Accelerator not available or If CUDA, Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
a = torch.randn(*a_shape, dtype=torch.bfloat16, device=torch.accelerator.current_accelerator().type)
b = torch.randn(64, 32, dtype=torch.bfloat16, device=torch.accelerator.current_accelerator().type)

linear_mm_config = LinearMMConfig()

@@ -270,7 +271,7 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_fp8_dtype(
self,
):
@@ -317,7 +318,8 @@ def _test_linear_impl(
torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad)

@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
"emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda"
and not is_sm_at_least_89()) else [True, False]
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize(
@@ -335,7 +337,7 @@ def _test_linear_impl(
@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
@pytest.mark.parametrize("use_ac", [False, True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_linear_from_config_params(
self,
x_shape,
@@ -347,8 +349,8 @@ def test_linear_from_config_params(
linear_bias: bool,
use_ac: bool,
):
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
x = torch.randn(*x_shape, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)

config = get_test_float8_linear_config(
scaling_type_input,
@@ -380,7 +382,7 @@ def test_linear_from_config_params(
@pytest.mark.parametrize(
"linear_dtype", [torch.bfloat16, torch.float16, torch.float32]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
@skip_if_rocm("ROCm enablement in progress")
def test_linear_from_recipe(
self,
@@ -389,14 +391,9 @@ def test_linear_from_recipe(
linear_dtype: torch.dtype,
linear_bias: bool,
):
if torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
x = torch.randn(*x_shape, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)
config = Float8LinearConfig.from_recipe_name(recipe_name)
self._test_linear_impl(
x,
@@ -405,37 +402,38 @@ def test_linear_from_recipe(
)

@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
"emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda"
and not is_sm_at_least_89()) else [True, False]
)
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_autocast_outputs(
self,
emulate: bool,
linear_dtype: torch.dtype,
):
m_ref = nn.Sequential(
nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
nn.Linear(32, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype),
nn.Linear(32, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype),
)
config = Float8LinearConfig(
emulate=emulate,
)
m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
x = torch.randn(16, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
with torch.autocast(torch.accelerator.current_accelerator().type):
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
with torch.autocast(torch.accelerator.current_accelerator().type, dtype=torch.bfloat16):
y = m(x)
assert y.dtype == torch.bfloat16, (
f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
@@ -445,28 +443,29 @@ def test_autocast_outputs(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
"emulate", [True] if (torch.accelerator.current_accelerator().type=="cuda"
and not is_sm_at_least_89()) else [True, False]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = nn.Linear(32, 16, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)
config = Float8LinearConfig(emulate=emulate)
m = Float8Linear.from_float(copy.deepcopy(m), config)

# Cast the module to dtype
m = m.to(dtype=linear_dtype)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
x = torch.randn(16, 32, device=torch.accelerator.current_accelerator().type, dtype=linear_dtype)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
with torch.autocast(torch.accelerator.current_accelerator().type):
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
with torch.autocast(torch.accelerator.current_accelerator().type, dtype=torch.bfloat16):
y = m(x)
assert y.dtype == torch.bfloat16, (
f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
@@ -484,18 +483,22 @@ def test_repr(self):
s = m.__repr__()
assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s

@unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
@unittest.skipIf(not torch.accelerator.is_available() or
(not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"),
"Accelerator not available or If CUDA, arch 8.9 not available")
def test_inference_mode(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
x = torch.randn(32, 32, device=torch.accelerator.current_accelerator().type)
m = nn.Sequential(nn.Linear(32, 32)).to(device=torch.accelerator.current_accelerator().type)
m = convert_to_float8_training(m)
with torch.inference_mode(mode=True):
m(x)

@unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
@unittest.skipIf(not torch.accelerator.is_available() or
(not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"),
"Accelerator not available or If CUDA, arch 8.9 not available")
def test_quantize(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
x = torch.randn(32, 32, device=torch.accelerator.current_accelerator().type)
m = nn.Sequential(nn.Linear(32, 32)).to(device=torch.accelerator.current_accelerator().type)
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
from torchao.quantization.quant_api import float8_weight_only, quantize_
@@ -509,10 +512,9 @@ def test_quantize(self):


class TestScaledMM:
@unittest.skipIf(
not is_sm_at_least_89(),
"CUDA not available",
)
@unittest.skipIf(not torch.accelerator.is_available() or
(not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"),
"Accelerator not available or If CUDA, arch 8.9 not available")
@pytest.mark.parametrize(
"base_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@@ -523,8 +525,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
output_dtype = base_dtype
compare_type = torch.float32

a = torch.randn(16, 16, device="cuda", dtype=base_dtype)
b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
a = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type, dtype=base_dtype)
b = torch.randn(32, 16, device=torch.accelerator.current_accelerator().type, dtype=base_dtype).t()

a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()
@@ -555,10 +557,12 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available() or
(not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"),
"Accelerator not available or If CUDA, arch 8.9 not available")
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
x_fp32 = torch.randn(16, 16, device=torch.accelerator.current_accelerator().type)
x_scale = torch.tensor(1.0, device=torch.accelerator.current_accelerator().type)
fp8_dtype = e4m3_dtype
linear_config_a = LinearMMConfig(
ScaledMMConfig(False, True, False, False),
@@ -590,10 +594,9 @@ def test_different_configs_error(self):
):
a @ b

@unittest.skipIf(
not is_sm_at_least_89(),
"CUDA not available",
)
@unittest.skipIf(not torch.accelerator.is_available() or
(not is_sm_at_least_89() and torch.accelerator.current_accelerator().type == "cuda"),
"Accelerator not available or If CUDA, arch 8.9 not available")
@pytest.mark.parametrize(
"base_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@@ -603,8 +606,8 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
input_dtype = e4m3_dtype
compare_type = torch.float32

a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
a = torch.randn(16, 41, device=torch.accelerator.current_accelerator().type, dtype=base_dtype)
b = torch.randn(41, 128, device=torch.accelerator.current_accelerator().type, dtype=base_dtype)

a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()
@@ -682,7 +685,7 @@ class TestNumerics:
torch.float8_e5m2fnuz,
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not torch.accelerator.is_available(), "Accelerator not available")
def test_small_amax_float16(self, float8_dtype):
# If we calculate scale naively with FP8_MAX_POS / amax,
# the result may not be representable in fp16. Verify that
@@ -701,7 +704,7 @@ def test_small_amax_float16(self, float8_dtype):
FP16_MAX_POS = torch.finfo(torch.float16).max

target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
x = torch.tensor([target_amax], dtype=torch.float16, device=torch.accelerator.current_accelerator().type)
scale = tensor_to_scale(x, float8_dtype)
assert not torch.any(torch.isinf(scale))
