Commit 63d142c

Lint fixes torchao/profiler and torchao/testing (#1368)
1 parent 2f97b09 commit 63d142c

6 files changed, +79 -60 lines

ruff.toml

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ include = [
     "torchao/quantization/**/*.py",
     "torchao/dtypes/**/*.py",
     "torchao/sparsity/**/*.py",
+    "torchao/profiler/**/*.py",
+    "torchao/testing/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
     "torchao/ops.py",

torchao/profiler/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,4 +1,3 @@
-
 # Re-exports
 from .device_spec import CUDADeviceSpec, DeviceSpec
 from .performance_counter import (
@@ -20,4 +19,3 @@
     "DeviceSpec",
     "total_model_params",
 ]
-

torchao/testing/float8/dtensor_utils.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-import torch
 import torch.nn as nn
 import torch.nn.functional as F

torchao/testing/float8/fsdp2_utils.py

Lines changed: 20 additions & 10 deletions
@@ -1,16 +1,13 @@
-import contextlib
-from typing import List, Optional
+from typing import List

 import torch
 import torch.distributed as dist
 import torch.nn as nn

-import torchao.float8.config as config
 from torchao.float8.config import (
     Float8LinearConfig,
     ScalingType,
 )
-
 from torchao.float8.float8_linear_utils import (
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
@@ -52,7 +49,11 @@ def check_parity_no_mp(
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)

-        test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
+        test_cls.assertEqual(
+            losses[0],
+            losses[1],
+            msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
+        )


 def check_parity_bf16_mp(
@@ -87,7 +88,11 @@ def check_parity_bf16_mp(
             ref_model.parameters(), ref_model_bf16.parameters()
         ):
             param_bf16.detach().copy_(param_fp32)
-        test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
+        test_cls.assertEqual(
+            losses[0],
+            losses[1],
+            msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
+        )


 def check_parity_fp8_comm_only(
@@ -104,7 +109,6 @@ def check_parity_fp8_comm_only(
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
         for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
-
             optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             losses.append(model(local_inp).sum())
             losses[-1].backward()
@@ -123,9 +127,15 @@ def check_parity_fp8_comm_only(
                 and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
-
+
         if compile:
             # When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now.
-            assert (torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}"
+            assert (
+                torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()
+            ), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}"
         else:
-            test_cls.assertEqual(losses[0], losses[1], f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
+            test_cls.assertEqual(
+                losses[0],
+                losses[1],
+                f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
+            )
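
For orientation, the parity helpers reformatted above all share the same shape: step a reference model and an FSDP float8 model on the same input and assert that the per-iteration losses agree. Below is a minimal sketch of that loop; the test class, models, optimizers, and input are placeholders, and the float8-specific steps the real helpers perform (amax/scale syncing, dynamic-scale precompute) are omitted.

from typing import List

import torch


def check_losses_match(test_cls, ref_model, ref_optim, test_model, test_optim, inp):
    # Mirrors the structure of check_parity_no_mp / check_parity_fp8_comm_only:
    # run both models side by side for a few iterations and compare the losses.
    for iter_idx in range(10):
        losses: List[torch.Tensor] = []
        for model, optim in ((ref_model, ref_optim), (test_model, test_optim)):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            losses.append(model(inp).sum())
            losses[-1].backward()
            optim.step()
        test_cls.assertEqual(
            losses[0],
            losses[1],
            msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
        )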

torchao/testing/float8/test_utils.py

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 import torch
+
 from torchao.float8.config import (
-    ScalingGranularity,
-    ScalingType,
-    CastConfig,
+    CastConfig,
     Float8LinearConfig,
+    ScalingType,
 )


torchao/testing/utils.py

Lines changed: 54 additions & 44 deletions
@@ -1,15 +1,19 @@
-import unittest
-import functools
 import copy
-import torch
-import torchao
-import os
+import functools
+import unittest

+import torch
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
-from torchao.dtypes import AffineQuantizedTensor
-from torchao.dtypes import to_affine_quantized_intx
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
+import torchao
+from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
+from torchao.quantization import int8_weight_only, quantize_
 from torchao.quantization.quant_primitives import MappingType
-from torchao.quantization import quantize_, int8_weight_only
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 """
@@ -36,10 +40,9 @@ class MyTestCase(TorchAOBasicTestCase):
     unittest.main()
 """

+
 # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
-def copy_tests(
-    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
-):  # noqa: B902
+def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None):  # noqa: B902
     for name, value in my_cls.__dict__.items():
         if name.startswith("test_"):
             # You cannot copy functions in Python, so we use closures here to
@@ -70,7 +73,6 @@ def new_test(self, value=value):
             setattr(other_cls, f"{name}_{suffix}", new_test)


-
 class TorchAOBasicTestCase(common_utils.TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
@@ -90,17 +92,21 @@ def test_flatten_unflatten(self):
         hp_tensor = torch.randn(4, 128)
         lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
         tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+        tensor_data_dict = {
+            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+        }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
-        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(
+            tensor_data_dict, tensor_attributes, outer_size, outer_stride
+        )
         self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_hp_tensor_device_dtype(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
-        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        self.FACTORY_FN(hp_tensor, **self.kwargs)

     @common_utils.parametrize("device1", COMMON_DEVICES)
     @common_utils.parametrize("device2", COMMON_DEVICES)
@@ -141,7 +147,10 @@ def test_linear(self, device, dtype):
         hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
         hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
         lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
-        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
+        self.assertGreater(
+            torchao.quantization.utils.compute_error(hp_res, lp_res),
+            self.LINEAR_MIN_SQNR,
+        )


 class TorchAOCompileTestCase(common_utils.TestCase):
@@ -165,6 +174,7 @@ class TorchAOCompileTestCase(common_utils.TestCase):
     def test_input_output_tensor_subclass(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
         lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+
         def f(tensor):
             return tensor

@@ -179,6 +189,7 @@ def f(tensor):
     def test_input_tensor_subclass(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
         lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+
         def f(tensor):
             return tensor.dequantize()

@@ -192,6 +203,7 @@ def f(tensor):
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_output_tensor_subclass(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+
         def f(hp_tensor):
             return self.FACTORY_FN(hp_tensor, **self.kwargs)

@@ -201,7 +213,12 @@ def f(hp_tensor):
         self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
         # bfloat16 seems to result in much larger numerical differences
         if dtype != torch.bfloat16:
-            self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
+            self.assertGreater(
+                torchao.quantization.utils.compute_error(
+                    ref.dequantize(), compiled.dequantize()
+                ),
+                self.COMPILE_MIN_SQNR,
+            )

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -211,22 +228,18 @@ def test_linear_compile(self, device, dtype):

         hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
         hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
-        l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
-        l.weight = torch.nn.Parameter(lp_tensor)
-        lp_res = torch.compile(l)(hp_act_tensor)
-        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
+        linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
+        linear.weight = torch.nn.Parameter(lp_tensor)
+        lp_res = torch.compile(linear)(hp_act_tensor)
+        self.assertGreater(
+            torchao.quantization.utils.compute_error(hp_res, lp_res),
+            self.LINEAR_MIN_SQNR,
+        )

-import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-    NUM_DEVICES,
-)

 class TorchAOTensorParallelTestCase(DTensorTestBase):
-    """Basic test case for tensor subclasses
-    """
+    """Basic test case for tensor subclasses"""
+
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

     TENSOR_SUBCLASS = AffineQuantizedTensor
@@ -247,9 +260,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         # Construct DTensor from local shard
         dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
         # Replace parameter in module
-        m.linear.weight = torch.nn.Parameter(
-            dtensor, requires_grad=False
-        )
+        m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
         return m

     @staticmethod
@@ -266,9 +277,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         # Construct DTensor from local shard
         dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
         # Replace parameter in module
-        m.linear.weight = torch.nn.Parameter(
-            dtensor, requires_grad=False
-        )
+        m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
         return m

     def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
@@ -289,7 +298,9 @@ def test_tp(self, dtype):
         class M(torch.nn.Module):
             def __init__(self, in_features, out_features, **kwargs) -> None:
                 super().__init__(**kwargs)
-                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+                self.linear = torch.nn.Linear(
+                    in_features, out_features, bias=False, device="cuda"
+                )

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return self.linear(x)
@@ -301,12 +312,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
-        y = proj_dn(proj_up(example_input))
+        proj_dn(proj_up(example_input))

         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
-        y_q = dn_quant(up_quant(example_input))
+        dn_quant(up_quant(example_input))

         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"
@@ -316,11 +327,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dn_dist = self.rowwise_shard(dn_quant, mesh)

         # We need to turn inputs into DTensor form as well -- just a format change
-        input_dtensor = DTensor.from_local(
-            example_input, mesh, [Replicate()]
-        )
+        input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

-        y_d = dn_dist(up_dist(input_dtensor))
+        dn_dist(up_dist(input_dtensor))

         if not TORCH_VERSION_AT_LEAST_2_6:
             # Need torch 2.6 to support compiled tensor parallelism
@@ -329,7 +338,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         up_compiled = torch.compile(up_dist)
         y_up = up_compiled(input_dtensor)
         dn_compiled = torch.compile(dn_dist)
-        y_dn = dn_compiled(y_up)
+        dn_compiled(y_up)
+

 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
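
As a usage note for the copy_tests helper reformatted above: it copies every test_* method from one class onto another under a suffixed name, which is how device- or backend-specific variants of a shared test class are stamped out. A minimal, hypothetical example (BaseTests and CudaTests are illustrative and not part of this commit):

import unittest

from torchao.testing.utils import copy_tests


class BaseTests(unittest.TestCase):
    def test_add(self):
        self.assertEqual(1 + 1, 2)


class CudaTests(unittest.TestCase):
    pass


# Copies BaseTests.test_add onto CudaTests as CudaTests.test_add_cuda.
copy_tests(BaseTests, CudaTests, "cuda")

if __name__ == "__main__":
    unittest.main()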
