Skip to content

Commit 769ffa5

Browse files
authored
add cast config for fp8 enablement
Differential Revision: D75945415 Pull Request resolved: #2328
1 parent 83663b8 commit 769ffa5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchao/float8/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
CastConfig,
44
Float8GemmConfig,
55
Float8LinearConfig,
6+
ScalingGranularity,
67
ScalingType,
78
)
8-
from torchao.float8.float8_linear_utils import (
9-
convert_to_float8_training,
10-
)
9+
from torchao.float8.float8_linear_utils import convert_to_float8_training
1110
from torchao.float8.float8_tensor import (
1211
Float8Tensor,
1312
GemmInputRole,
@@ -39,6 +38,7 @@
3938
"Float8GemmConfig",
4039
"Float8LinearConfig",
4140
"CastConfig",
41+
"ScalingGranularity",
4242
# top level UX
4343
"convert_to_float8_training",
4444
"precompute_float8_dynamic_scale_for_fsdp",

0 commit comments

Comments
 (0)