Skip to content

Commit d6cfdad

Browse files
authored
Fix broken circular dep error
Differential Revision: D75980282 Pull Request resolved: #2320
1 parent 488ecd4 commit d6cfdad

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

torchao/float8/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
1818
from torchao.float8.inference import Float8MMConfig
19+
from torchao.float8.types import FP8Granularity
1920
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
2021

2122
if TORCH_VERSION_AT_LEAST_2_5:
@@ -41,5 +42,7 @@
4142
# top level UX
4243
"convert_to_float8_training",
4344
"precompute_float8_dynamic_scale_for_fsdp",
45+
# types
46+
"FP8Granularity",
4447
# note: Float8Tensor and Float8Linear are not public APIs
4548
]

torchao/float8/inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313

1414
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
15+
from torchao.float8.types import FP8Granularity
1516
from torchao.quantization.granularity import (
1617
PerRow,
1718
PerTensor,
@@ -116,9 +117,6 @@ def _is_rowwise_scaled(x) -> bool:
116117
return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],)
117118

118119

119-
FP8Granularity = Union[PerTensor, PerRow]
120-
121-
122120
def _normalize_granularity(
123121
granularity: Optional[
124122
Union[

torchao/float8/types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
"""
7+
Common types for float8 quantization
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from typing import TYPE_CHECKING, Union
13+
14+
if TYPE_CHECKING:
15+
from torchao.quantization.granularity import PerRow, PerTensor
16+
17+
18+
# Define FP8Granularity type alias to break circular import dependencies
19+
FP8Granularity = Union["PerTensor", "PerRow"]

0 commit comments

Comments
 (0)