
Commit dd2029e

qzzz95 and akaitsuki-ii authored
support fp8 linear on AMD (#86)
* support fp8 linear on AMD
* up
* add comment
* move

---------

Co-authored-by: zhuguoxuan.zgx <zhuguoxuan.zgx@alibaba-inc.com>
1 parent 3349b65 commit dd2029e

File tree

2 files changed: +23 −6 lines


diffsynth_engine/utils/fp8_linear.py

Lines changed: 14 additions & 5 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from contextlib import contextmanager
+from diffsynth_engine.utils.platform import DTYPE_FP8


 def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
@@ -51,7 +52,7 @@ def enable_fp8_linear(module: nn.Module):
 def _enable_fp8_linear(module: nn.Module):
     if isinstance(module, nn.Linear) and torch.is_floating_point(module.weight.data):
         # avoid conversion for int weights like GGUF
-        module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
+        module.weight.data = module.weight.data.to(DTYPE_FP8)
     for submodule in module.children():
         _enable_fp8_linear(submodule)

@@ -71,16 +72,24 @@ def fp8_linear(
 ) -> torch.Tensor:
     device = input.device
     origin_dtype = input.dtype
-    input = input.to(torch.float8_e4m3fn)
-    weight = weight.to(torch.float8_e4m3fn)
+    scale_a = 1.0
+    # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
+    # To avoid overflow and ensure numerical compatibility during FP8 computation,
+    # we scale down the input by 2.0 in advance.
+    # This scaling will be compensated later during the final result scaling.
+    if DTYPE_FP8 == torch.float8_e4m3fnuz:
+        scale_a = 2.0
+        input = input / scale_a
+    input = input.to(DTYPE_FP8)
+    weight = weight.to(DTYPE_FP8)

     if len(input.shape) > 2:
         origin_shape = input.shape
         input = input.reshape(-1, origin_shape[-1])
         result = torch._scaled_mm(
             input,
             weight.T,
-            scale_a=torch.tensor(1.0).to(device=device),
+            scale_a=torch.tensor(scale_a).to(device=device),
             scale_b=torch.tensor(1.0).to(device=device),
             bias=bias,
             out_dtype=origin_dtype,
@@ -91,7 +100,7 @@ def fp8_linear(
         result = torch._scaled_mm(
             input,
             weight.T,
-            scale_a=torch.tensor(1.0).to(device=device),
+            scale_a=torch.tensor(scale_a).to(device=device),
             scale_b=torch.tensor(1.0).to(device=device),
             bias=bias,
             out_dtype=origin_dtype,
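
The comment added in the hunk above explains the halving; as a minimal standalone sketch (not part of this commit, assuming a GPU and PyTorch build where torch._scaled_mm supports FP8 and returns a plain tensor, and that diffsynth_engine is importable), the scaling is value-preserving because _scaled_mm multiplies the accumulated product by scale_a * scale_b, so passing scale_a=2.0 undoes the pre-division after the matmul while the quantized activations stay within float8_e4m3fnuz's smaller range (max finite value 240, versus 448 for float8_e4m3fn). The shapes and bfloat16 compute dtype below are illustrative.

# Illustrative sketch only; assumes a CUDA/ROCm device with FP8 _scaled_mm support.
import torch
from diffsynth_engine.utils.platform import DTYPE_FP8

x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")  # activations
w = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")  # weight, shape (out_features, in_features)

# Halve the activations before quantization only when the platform dtype is e4m3fnuz,
# mirroring the patched fp8_linear.
scale_a = 2.0 if DTYPE_FP8 == torch.float8_e4m3fnuz else 1.0
x_fp8 = (x / scale_a).to(DTYPE_FP8)
w_fp8 = w.to(DTYPE_FP8)

y = torch._scaled_mm(
    x_fp8,
    w_fp8.T,                                         # column-major view, as _scaled_mm expects
    scale_a=torch.tensor(scale_a, device=x.device),  # scales the result, undoing the pre-division
    scale_b=torch.tensor(1.0, device=x.device),
    out_dtype=torch.bfloat16,
)

# Up to FP8 rounding error, y matches the unquantized reference.
print((y - x @ w.T).abs().max())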

diffsynth_engine/utils/platform.py

Lines changed: 9 additions & 1 deletion

@@ -1,7 +1,15 @@
+# cross-platform definitions and utilities
 import torch
 import gc

-# 存放跨平台的工具类
+
+# data type
+# AMD only supports float8_e4m3fnuz
+# https://onnx.ai/onnx/technical/float8.html
+if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
+    DTYPE_FP8 = torch.float8_e4m3fnuz
+else:
+    DTYPE_FP8 = torch.float8_e4m3fn


 def empty_cache():
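
To see which dtype this guard selects on a given machine, a quick probe along the following lines can help (hypothetical snippet, not from the repo; it assumes at least one visible GPU). Note that torch.version.hip is None on CUDA builds, so the condition short-circuits there, and the "gfx94" substring matches MI300-class architectures such as gfx940/gfx941/gfx942.

# Hypothetical probe mirroring the platform.py guard; requires a visible GPU.
import torch

is_rocm = torch.version.hip is not None
arch = torch.cuda.get_device_properties(0).gcnArchName if is_rocm else "n/a"
dtype_fp8 = torch.float8_e4m3fnuz if (is_rocm and "gfx94" in arch) else torch.float8_e4m3fn

print(f"ROCm build: {is_rocm}, arch: {arch}, selected FP8 dtype: {dtype_fp8}")
print("max finite value:", torch.finfo(dtype_fp8).max)  # 240.0 for e4m3fnuz, 448.0 for e4m3fn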
