
Commit 0db370f

Rename torchao.float8.Float8Tensor to torchao.float8.Float8TrainingTensor
Summary: att, since we are introducing an inference version of Float8Tensor

Test Plan: regression tests for float8 training: pytest test/float8

Reviewers:
Subscribers:
Tasks:
Tags:

stack-info: PR: #2479, branch: jerryzh168/stack/11
1 parent c336426 commit 0db370f
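
For code that imports from the old module path, the rename means updating both the module and the class name. A minimal sketch of the before/after, using only names that appear in the diffs below:

# before: from torchao.float8.float8_tensor import Float8Tensor, ScaledMMConfig
# after:
from torchao.float8.float8_training_tensor import (
    Float8TrainingTensor,
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)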

22 files changed: +160 −141 lines

benchmarks/float8/bench_linear_float8.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
     ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_tensor import ScaledMMConfig
+from torchao.float8.float8_training_tensor import ScaledMMConfig
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N

benchmarks/float8/bench_padding.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from torch._inductor.utils import do_bench_using_profiling
 from tqdm import tqdm
 
-from torchao.float8.float8_tensor import (
+from torchao.float8.float8_training_tensor import (
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,

test/float8/test_base.py

Lines changed: 5 additions & 5 deletions
@@ -40,8 +40,8 @@
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    Float8Tensor,
+from torchao.float8.float8_training_tensor import (
+    Float8TrainingTensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
@@ -60,13 +60,13 @@
 torch.manual_seed(0)
 
 
-def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
+def bitwise_identical(a: Float8TrainingTensor, b: Float8TrainingTensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
     return True
 
 
-class TestFloat8Tensor:
+class TestFloat8TrainingTensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -128,7 +128,7 @@ def test_copy_(self):
         with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail
 
-        fp8_b = Float8Tensor(
+        fp8_b = Float8TrainingTensor(
             torch.empty(16, dtype=e4m3_dtype),
             scale_a,
             torch.bfloat16,

test/float8/test_compile.py

Lines changed: 11 additions & 7 deletions
@@ -36,7 +36,11 @@
 from torchao.float8.float8_scaling_utils import (
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
+from torchao.float8.float8_training_tensor import (
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+)
 from torchao.testing.training.test_utils import get_test_float8_linear_config
 
 
@@ -238,7 +242,7 @@ def forward(self, x):
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
-        """Test that having Float8Tensor object at the boundary of a subgraph"""
+        """Test that having Float8TrainingTensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=True).cuda()
         compiled_mod = copy.deepcopy(mod)
@@ -254,7 +258,7 @@ def test_float8_with_graph_break_in_the_middle(self):
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
-        """Test that having Float8Tensor object as a graph input"""
+        """Test that having Float8TrainingTensor object as a graph input"""
 
         def to_float(x):
             return x.to_original_precision()
@@ -278,7 +282,7 @@ def to_float(x):
         "CUDA with float8 support not available",
    )
     def test_float8_graph_output(self):
-        """Test that having Float8Tensor object as a graph output works"""
+        """Test that having Float8TrainingTensor object as a graph output works"""
         cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=False).cuda()
         compiled_mod = torch.compile(mod, backend=cnts)
@@ -290,14 +294,14 @@ def test_float8_graph_output(self):
         for tensor in tensors:
             assert not isinstance(
                 getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
-            ), "Float8Tensor should not contain any FakeTensors!"
+            ), "Float8TrainingTensor should not contain any FakeTensors!"
         assert isinstance(y_compiled._orig_dtype, torch.dtype), (
-            "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+            "Float8TrainingTensor._orig_dtype should be a dtype but got {}".format(
                 type(y_compiled._orig_dtype)
             )
         )
         assert isinstance(y_compiled._linear_mm_config.output.emulate, bool), (
-            "Float8Tensor._emulate should be a bool but got {}".format(
+            "Float8TrainingTensor._emulate should be a bool but got {}".format(
                 type(y_compiled._linear_mm_config.output.emulate)
             )
         )

test/float8/test_dtensor.py

Lines changed: 5 additions & 5 deletions
@@ -37,8 +37,8 @@
 )
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
-from torchao.float8.float8_tensor import (
-    Float8Tensor,
+from torchao.float8.float8_training_tensor import (
+    Float8TrainingTensor,
     GemmInputRole,
     LinearMMConfig,
     hp_tensor_and_scale_to_float8,
@@ -94,8 +94,8 @@ def _test_scaled_mm(mesh: DeviceMesh, size=16):
     dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
     dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False)
 
-    assert isinstance(dist_x_fp8.to_local(), Float8Tensor)
-    assert isinstance(dist_y_fp8.to_local(), Float8Tensor)
+    assert isinstance(dist_x_fp8.to_local(), Float8TrainingTensor)
+    assert isinstance(dist_y_fp8.to_local(), Float8TrainingTensor)
     assert dist_x_fp8.to_local()._orig_dtype == torch.float32
     out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8)
     local_fp8_out = out_fp8.to_local()
@@ -128,7 +128,7 @@ def _test_fp8_redistribute(mesh: DeviceMesh, size=16):
     if isinstance(out_local, AsyncCollectiveTensor):
         out_local = out_local.wait()
 
-    assert isinstance(out_local, Float8Tensor)
+    assert isinstance(out_local, Float8TrainingTensor)
     assert out_local._data.dtype == fp8_dtype
 
 
test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_tensor import GemmInputRole
+from torchao.float8.float8_training_tensor import GemmInputRole
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.training.fsdp2_utils import (
     check_parity_bf16_mp,

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
     Float8LinearRecipeName,
 )
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,

test/quantization/test_qat.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from torchao import quantize_
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -1696,7 +1696,7 @@ def test_qat_range_learning(self):
 
     def test_float8_rowwise_fake_quantize(self):
         """
-        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8Tensor`.
+        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
         """
         torch.manual_seed(self.SEED)
         dtype = torch.float8_e4m3fn

torchao/dtypes/nf4tensor.py

Lines changed: 1 addition & 1 deletion
@@ -943,7 +943,7 @@ def allowed_subclasses(type):
                 f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
             )
 
-    # Do not force the Float8Tensor type on the returned tensor
+    # Do not force the Float8TrainingTensor type on the returned tensor
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):

torchao/float8/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -10,8 +10,8 @@
     _auto_filter_for_recipe,
     convert_to_float8_training,
 )
-from torchao.float8.float8_tensor import (
-    Float8Tensor,
+from torchao.float8.float8_training_tensor import (
+    Float8TrainingTensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
@@ -22,12 +22,12 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if TORCH_VERSION_AT_LEAST_2_5:
-    # Needed to load Float8Tensor with weights_only = True
+    # Needed to load Float8TrainingTensor with weights_only = True
     from torch.serialization import add_safe_globals
 
     add_safe_globals(
         [
-            Float8Tensor,
+            Float8TrainingTensor,
             ScaledMMConfig,
             GemmInputRole,
             LinearMMConfig,
@@ -50,5 +50,5 @@
     "_auto_filter_for_recipe",
     # types
     "FP8Granularity",
-    # note: Float8Tensor and Float8Linear are not public APIs
+    # note: Float8TrainingTensor and Float8Linear are not public APIs
 ]
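
The add_safe_globals registration in torchao/float8/__init__.py above is what keeps serialized Float8TrainingTensor, ScaledMMConfig, GemmInputRole, and LinearMMConfig objects loadable under the weights_only unpickler on torch >= 2.5. A minimal sketch of what the registration enables; the checkpoint path is a placeholder:

import torch
import torchao.float8  # importing the package runs add_safe_globals on torch >= 2.5

# a checkpoint that pickles Float8TrainingTensor / ScaledMMConfig objects can now be
# read by the restricted loader; "ckpt.pt" is a hypothetical path
state = torch.load("ckpt.pt", weights_only=True)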
