
Commit cc8e80b

Lint tutorials (#1520)
1 parent 070345d, commit cc8e80b

15 files changed: +4547 -1460 lines

ruff.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ include = [
     "torchao/**/*.py",
     "test/**/*.py",
     "benchmarks/**/*.py",
+    "tutorials/**/*.py",
 ]
 
 exclude = [

tutorials/add_an_op.py

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+
 import torchao
 from torchao.dtypes import to_nf4
 
@@ -20,6 +21,7 @@
 # NotImplementedError: NF4Tensor dispatch: attempting to run aten.gelu.default, this is not supported
 # torch.nn.functional.gelu(a_nf4)
 
+
 # Next you can add this function using the implements decorator
 @torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])
 def gelu(func, *args, **kwargs):
@@ -30,7 +32,12 @@ def gelu(func, *args, **kwargs):
     # We're getting the first argument of the original args
     inp = args[0][0]
     # There's a way very inefficient way to implement it
-    return to_nf4(torch.nn.functional.gelu(inp.to(torch.float32)), inp.block_size, inp.scaler_block_size)
+    return to_nf4(
+        torch.nn.functional.gelu(inp.to(torch.float32)),
+        inp.block_size,
+        inp.scaler_block_size,
+    )
+
 
 print(f"gelu(a): {torch.nn.functional.gelu(a)}")
 print(f"gelu(a_nf4): {torch.nn.functional.gelu(a_nf4)}")
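The pattern in this tutorial generalizes to any other aten op that NF4Tensor does not yet handle: register a handler with the implements decorator, dequantize to float, run the op, and re-quantize with to_nf4. A minimal sketch for aten.relu.default (hypothetical, not part of this commit) would mirror the gelu handler above:

import torch
import torchao
from torchao.dtypes import to_nf4

# Hypothetical example: register a fallback for relu the same way the tutorial registers gelu.
@torchao.dtypes.nf4tensor.implements([torch.ops.aten.relu.default])
def relu(func, *args, **kwargs):
    # The first positional argument of the original aten call is the NF4 input tensor.
    inp = args[0][0]
    # Dequantize to float32, apply the op, then re-quantize back to NF4.
    return to_nf4(
        torch.nn.functional.relu(inp.to(torch.float32)),
        inp.block_size,
        inp.scaler_block_size,
    )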

tutorials/calibration_flow/awq_like.py

Lines changed: 83 additions & 27 deletions
@@ -6,35 +6,47 @@
 * then we apply equalization scale to linear activation with to_weight_tensor_with_linear_activation_scale_metadata (input activation will be divided by equalization_scale), and then call F.linear with
 scaled input activation and quantized weight (so we can reuse the efficient quantized linear kernels used by quantized weight)
 """
-import torch
+
 import copy
 
+import torch
 import torch.nn.functional as F
 from torch import Tensor
+
 from torchao.dtypes import (
-    to_affine_quantized_intx_static,
-    to_affine_quantized_floatx_static,
     Float8Layout,
+    to_affine_quantized_floatx_static,
+    to_affine_quantized_intx_static,
 )
-from torchao.quantization.utils import compute_error
-from torchao.quantization import quantize_
-from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.quantization.observer import (
-    AffineQuantizedMinMaxObserver,
+from torchao.quantization import (
+    quantize_,
+    to_weight_tensor_with_linear_activation_scale_metadata,
 )
 from torchao.quantization.granularity import (
     PerAxis,
     PerTensor,
 )
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+)
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.quant_primitives import (
     MappingType,
-    FP8_TYPES,
 )
+from torchao.quantization.utils import compute_error
 
 
 class ObservedLinear(torch.nn.Linear):
-    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        act_obs: torch.nn.Module,
+        weight_obs: torch.nn.Module,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ):
         super().__init__(in_features, out_features, bias, device, dtype)
         self.act_obs = act_obs
         self.weight_obs = weight_obs
@@ -46,11 +58,20 @@ def forward(self, input: Tensor):
 
     @classmethod
     def from_float(cls, float_linear, act_obs, weight_obs):
-        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
+        observed_linear = cls(
+            float_linear.in_features,
+            float_linear.out_features,
+            act_obs,
+            weight_obs,
+            False,
+            device=float_linear.weight.device,
+            dtype=float_linear.weight.dtype,
+        )
         observed_linear.weight = float_linear.weight
         observed_linear.bias = float_linear.bias
         return observed_linear
 
+
 def insert_observers_(model, act_obs, weight_obs):
     _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
 
@@ -61,22 +82,39 @@ def replacement_fn(m):
 
     _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)
 
+
 # converting observed linear module to linear module with quantzied weights (and quantized activations)
 # with tensor subclasses
 def apply_awq(target_dtype: torch.dtype):
     # target_dtype = torch.uint8
     def _apply_awq_to_linear(observed_linear):
         # weight quantization
         weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+
         def weight_quant_func(weight):
             block_size = (1, weight.shape[1])
             if target_dtype == torch.uint8:
-                return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+                return to_affine_quantized_intx_static(
+                    weight, weight_scale, weight_zero_point, block_size, target_dtype
+                )
             elif target_dtype == torch.float8_e4m3fn:
-                return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8Layout(mm_config=None))
+                return to_affine_quantized_floatx_static(
+                    weight,
+                    weight_scale,
+                    block_size,
+                    target_dtype,
+                    Float8Layout(mm_config=None),
+                )
             else:
                 raise ValueError(f"Unsupported target dtype {target_dtype}")
-        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
+
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            False,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
         linear.weight = observed_linear.weight
         linear.bias = observed_linear.bias
 
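A side note on weight_quant_func above: block_size = (1, weight.shape[1]) makes each quantization block span one full row of the weight, i.e. one scale per output channel, which lines up with the PerAxis(axis=0) weight observer configured later in this diff. A quick shape check (illustrative only, not from this commit):

import torch

# ToyLinearModel's linear2 below has weight shape (n, k) = (32, 64)
weight = torch.randn(32, 64)
block_size = (1, weight.shape[1])  # each block covers one output row
num_blocks = (weight.shape[0] // block_size[0]) * (weight.shape[1] // block_size[1])
print(num_blocks)  # 32 -> one scale/zero_point per output channel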
@@ -86,16 +124,22 @@ def weight_quant_func(weight):
         equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
         equalization_scale = torch.ones_like(equalization_scale)
 
-        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight * equalization_scale), requires_grad=False)
+        linear.weight = torch.nn.Parameter(
+            weight_quant_func(linear.weight * equalization_scale), requires_grad=False
+        )
 
-        linear.weight = torch.nn.Parameter(to_weight_tensor_with_linear_activation_scale_metadata(linear.weight, equalization_scale), requires_grad=False)
+        linear.weight = torch.nn.Parameter(
+            to_weight_tensor_with_linear_activation_scale_metadata(
+                linear.weight, equalization_scale
+            ),
+            requires_grad=False,
+        )
 
         return linear
 
     return _apply_awq_to_linear
 
 
-
 ######## Test ##########
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=64, n=32, k=64):
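The two Parameter assignments above rely on a simple identity: multiplying the weight's input channels by equalization_scale while dividing the activation by the same scale leaves the linear output unchanged, which is why the scale can be folded into the quantized weight and only a cheap division remains at runtime. A small sketch verifying the identity (illustrative only, not from this commit):

import torch
import torch.nn.functional as F

x = torch.randn(8, 64)    # activation: (batch, in_features)
w = torch.randn(32, 64)   # weight: (out_features, in_features)
s = torch.rand(64) + 0.5  # per-input-channel equalization scale

# F.linear(x / s, w * s) equals F.linear(x, w) up to floating point error:
# each input channel is scaled down in the activation and back up in the weight.
out_ref = F.linear(x, w)
out_eq = F.linear(x / s, w * s)
print(torch.allclose(out_ref, out_eq, rtol=1e-4, atol=1e-5))  # True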
@@ -104,7 +148,11 @@ def __init__(self, m=64, n=32, k=64):
         self.linear2 = torch.nn.Linear(k, n, bias=False)
 
     def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+        return (
+            torch.randn(
+                batch_size, self.linear1.in_features, dtype=dtype, device=device
+            ),
+        )
 
     def forward(self, x):
         x = self.linear1(x)
@@ -119,16 +167,24 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
     dtype = torch.bfloat16
     m = ToyLinearModel().eval().to(dtype).to("cuda")
 
-    m_for_test = copy.deepcopy(m)
-
     m_bf16 = copy.deepcopy(m)
     example_inputs = m.example_inputs(dtype=dtype, device="cuda")
     print("example inputs shape:", example_inputs[0].shape)
 
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune')
-
-    act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps)
-    weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune")
+
+    act_obs = AffineQuantizedMinMaxObserver(
+        mapping_type,
+        target_dtype,
+        granularity_type=PerTensor(),
+        eps=torch.finfo(torch.float32).eps,
+    )
+    weight_obs = AffineQuantizedMinMaxObserver(
+        mapping_type,
+        target_dtype,
+        granularity_type=PerAxis(axis=0),
+        eps=torch.finfo(torch.float32).eps,
+    )
 
     before_quant = m(*example_inputs)
 
@@ -137,9 +193,9 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
     for _ in range(10):
         m(*example_inputs)
 
-    after_obs = m(*example_inputs)
+    m(*example_inputs)
 
-    m2 = copy.deepcopy(m)
+    copy.deepcopy(m)
 
     is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
 
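For orientation, the test wires these pieces together roughly as follows; the final quantize_ call sits outside the hunks shown above, so treat this as a hedged sketch of the flow rather than the exact tutorial code:

# Assumed wiring of the calibration flow exercised by test_awq:
insert_observers_(m, act_obs, weight_obs)  # 1. swap nn.Linear -> ObservedLinear

for _ in range(10):  # 2. calibrate by running representative inputs
    m(*example_inputs)

# 3. convert observed modules back into plain nn.Linear with quantized, scale-folded weights
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
quantize_(m, apply_awq(target_dtype), is_observed_linear)

# 4. compare quantized outputs against the pre-quantization reference, e.g. with the
#    SQNR helper: compute_error(before_quant, m(*example_inputs))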