2
2
import torch .nn as nn
3
3
import torch .nn .functional as F
4
4
from contextlib import contextmanager
5
+ from diffsynth_engine .utils .platform import DTYPE_FP8
5
6
6
7
7
8
def enable_fp8_autocast (module : nn .Module , compute_dtype : torch .dtype = torch .bfloat16 , use_fp8_linear : bool = False ):
@@ -51,7 +52,7 @@ def enable_fp8_linear(module: nn.Module):
51
52
def _enable_fp8_linear(module: nn.Module) -> None:
    """Recursively cast the weights of every ``nn.Linear`` inside *module* to FP8.

    Walks the module tree depth-first; for each ``nn.Linear`` whose weight is a
    floating-point tensor, the weight data is converted in place to the
    platform-appropriate FP8 dtype.

    Args:
        module: Root module whose linear layers should be converted.
    """
    if isinstance(module, nn.Linear) and torch.is_floating_point(module.weight.data):
        # Guard on is_floating_point to avoid converting integer weights
        # (e.g. GGUF-quantized tensors), which must stay in their int dtype.
        # NOTE(review): DTYPE_FP8 is imported from diffsynth_engine.utils.platform;
        # presumably float8_e4m3fn on CUDA and float8_e4m3fnuz on ROCm — confirm.
        module.weight.data = module.weight.data.to(DTYPE_FP8)
    for submodule in module.children():
        _enable_fp8_linear(submodule)
57
58
@@ -71,16 +72,24 @@ def fp8_linear(
71
72
) -> torch .Tensor :
72
73
device = input .device
73
74
origin_dtype = input .dtype
74
- input = input .to (torch .float8_e4m3fn )
75
- weight = weight .to (torch .float8_e4m3fn )
75
+ scale_a = 1.0
76
+ # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
77
+ # To avoid overflow and ensure numerical compatibility during FP8 computation,
78
+ # we scale down the input by 2.0 in advance.
79
+ # This scaling will be compensated later during the final result scaling.
80
+ if DTYPE_FP8 == torch .float8_e4m3fnuz :
81
+ scale_a = 2.0
82
+ input = input / scale_a
83
+ input = input .to (DTYPE_FP8 )
84
+ weight = weight .to (DTYPE_FP8 )
76
85
77
86
if len (input .shape ) > 2 :
78
87
origin_shape = input .shape
79
88
input = input .reshape (- 1 , origin_shape [- 1 ])
80
89
result = torch ._scaled_mm (
81
90
input ,
82
91
weight .T ,
83
- scale_a = torch .tensor (1.0 ).to (device = device ),
92
+ scale_a = torch .tensor (scale_a ).to (device = device ),
84
93
scale_b = torch .tensor (1.0 ).to (device = device ),
85
94
bias = bias ,
86
95
out_dtype = origin_dtype ,
@@ -91,7 +100,7 @@ def fp8_linear(
91
100
result = torch ._scaled_mm (
92
101
input ,
93
102
weight .T ,
94
- scale_a = torch .tensor (1.0 ).to (device = device ),
103
+ scale_a = torch .tensor (scale_a ).to (device = device ),
95
104
scale_b = torch .tensor (1.0 ).to (device = device ),
96
105
bias = bias ,
97
106
out_dtype = origin_dtype ,
0 commit comments