
Commit 44b1543

ruff format
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent f9a218f commit 44b1543

File tree

5 files changed: +49 −22 lines


torchao/_models/mixtral-moe/generate.py

Lines changed: 20 additions & 6 deletions
@@ -238,8 +238,8 @@ def main(
 
         from torchao.quantization.prototype.moe_quant.utils import (
             MoEQuantConfig,
+            UseFakeExtraDimTensor,
             cond_ffn_filter,
-            UseFakeExtraDimTensor
         )
         from torchao.quantization.quant_api import (
             Float8DynamicActivationFloat8WeightConfig,

@@ -260,28 +260,42 @@ def main(
             config = MoEQuantConfig(Int8WeightOnlyConfig())
 
         elif "int8wo" in moe_quant:
-            config = MoEQuantConfig(Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)
+            config = MoEQuantConfig(
+                Int8WeightOnlyConfig(),
+                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+            )
 
         elif "int8dq-base" in moe_quant:
             config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
 
         elif "int8dq" in moe_quant:
-            config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)
+            config = MoEQuantConfig(
+                Int8DynamicActivationInt8WeightConfig(),
+                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+            )
 
         elif "int4wo-base" in moe_quant:
             config = MoEQuantConfig(Int4WeightOnlyConfig())
 
         elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)
+            config = MoEQuantConfig(
+                Int4WeightOnlyConfig(),
+                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+            )
 
         elif "fp8wo-base" in moe_quant:
             config = MoEQuantConfig(Float8WeightOnlyConfig())
 
         elif "fp8wo" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)
+            config = MoEQuantConfig(
+                Float8WeightOnlyConfig(),
+                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
+            )
 
         elif "fp8dq-base" in moe_quant:
-            config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+            config = MoEQuantConfig(
+                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            )
 
         elif "fp8dq" in moe_quant:
             config = MoEQuantConfig(
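
For context, the call sites reformatted above all build the same kind of config. Below is a minimal usage sketch (not part of this commit) using only names imported in the diff; `model` is assumed to be an already-loaded MoE model whose expert FFN modules match cond_ffn_filter.

# Sketch only: `model` is a hypothetical, already-loaded MoE model.
from torchao.quantization import quantize_
from torchao.quantization.prototype.moe_quant.utils import (
    MoEQuantConfig,
    UseFakeExtraDimTensor,
    cond_ffn_filter,
)
from torchao.quantization.quant_api import Int8WeightOnlyConfig

config = MoEQuantConfig(
    Int8WeightOnlyConfig(),
    use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,  # force the FakeExtraDimTensor path
)
quantize_(model, config, cond_ffn_filter)  # quantize only the expert FFN weights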

torchao/dtypes/floatx/float8_layout.py

Lines changed: 7 additions & 4 deletions
@@ -55,8 +55,10 @@ class Float8Layout(Layout):
 
     mm_config: Optional[Float8MMConfig] = None
 
+
 _fallback_warning_shown = False
 
+
 @register_layout(Float8Layout)
 class Float8AQTTensorImpl(AQTTensorImpl):
     """

@@ -102,7 +104,7 @@ def __init__(
     def _apply_fn_to_data(self, fn):
         """Applys a fn to all tensor components stored on this class"""
         global _fallback_warning_shown
-
+
         try:
             return self.__class__(
                 fn(self.float8_data),

@@ -114,14 +116,15 @@ def _apply_fn_to_data(self, fn):
             if '"index_cuda" not implemented for ' in str(e):
                 if not _fallback_warning_shown:
                     import warnings
+
                     warnings.warn(
                         f"When trying to index Float8AQTTensorImpl, got known error {e}, will use slower fallback but "
                         + "note: You can torch.compile the model to avoid this problem.",
-                        UserWarning
+                        UserWarning,
                     )
                     _fallback_warning_shown = True
-
-                return self.__class__( # do indexing in bfloat16 then convert back
+
+                return self.__class__(  # do indexing in bfloat16 then convert back
                     fn(self.float8_data.to(torch.bfloat16)).to(self.float8_data.dtype),
                     fn(self.scale),
                     self.transposed,
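
The float8_layout.py changes only reformat an existing warn-once fallback. As a reference, here is a self-contained sketch of that pattern with generic, hypothetical names (`indexed_with_fallback` is not torchao code); it mirrors the module-level flag plus bfloat16 retry seen in the diff.

import warnings

import torch

_fallback_warning_shown = False  # module-level flag so the warning fires only once


def indexed_with_fallback(data, fn):
    """Try fn on data directly; on NotImplementedError, warn once and retry in bfloat16."""
    global _fallback_warning_shown
    try:
        return fn(data)
    except NotImplementedError as e:
        if not _fallback_warning_shown:
            warnings.warn(
                f"Got known error {e}; using a slower fallback path.",
                UserWarning,
            )
            _fallback_warning_shown = True
        # fallback: convert to a widely supported dtype, apply fn, convert back
        return fn(data.to(torch.bfloat16)).to(data.dtype)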

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 4 additions & 2 deletions
@@ -636,8 +636,8 @@ def test_moe_quant_intx(self):
         from torchao.quantization.prototype.moe_quant.utils import (
             FakeExtraDimTensor,
             MoEQuantConfig,
-            cond_ffn_filter,
             UseFakeExtraDimTensor,
+            cond_ffn_filter,
         )
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,

@@ -657,7 +657,9 @@ def test_moe_quant_intx(self):
         base_config = Int8DynamicActivationIntxWeightConfig(
             layout=PackedLinearInt8DynamicActivationIntxWeightLayout()
         )
-        moe_config = MoEQuantConfig(base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE)
+        moe_config = MoEQuantConfig(
+            base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
+        )
 
         quantize_(model, moe_config, cond_ffn_filter)

torchao/quantization/prototype/moe_quant/llama4_quant.py

Lines changed: 4 additions & 1 deletion
@@ -70,7 +70,10 @@ def convert_fn(module):
 model = model
 
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.quantization.prototype.moe_quant.utils import cond_ffn_filter, MoEQuantConfig
+from torchao.quantization.prototype.moe_quant.utils import (
+    MoEQuantConfig,
+    cond_ffn_filter,
+)
 
 quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda")

torchao/quantization/prototype/moe_quant/utils.py

Lines changed: 14 additions & 9 deletions
@@ -5,16 +5,16 @@
 
 aten = torch.ops.aten
 
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Union
 
 from torchao.quantization.quant_api import (
+    _QUANTIZE_CONFIG_HANDLER,
     AOBaseConfig,
     dataclass,
     register_quantize_module_handler,
 )
 from torchao.utils import fill_defaults
-from enum import Enum, auto
-from torchao.quantization.quant_api import _QUANTIZE_CONFIG_HANDLER
 
 
 class DummyModule(torch.nn.Module):

@@ -213,9 +213,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
             raise e
 
+
 class UseFakeExtraDimTensor(Enum):
-    """Enum that indicate whether to use FakeExtraDimTensor
-    """
+    """Enum that indicate whether to use FakeExtraDimTensor"""
+
     TRUE = auto()
     FALSE = auto()
     AS_FALLBACK = auto()

@@ -230,12 +231,13 @@ class MoEQuantConfig(AOBaseConfig):
 
     base_config: AOBaseConfig
     use_fake_extra_dim_tensor: UseFakeExtraDimTensor = UseFakeExtraDimTensor.AS_FALLBACK
-    set_inductor_config: bool=True
+    set_inductor_config: bool = True
 
 
 # Module-level flag to track if we've already printed the error
 _moe_quant_tensor_has_printed_error = False
 
+
 def _moe_quant_tensor(weight, config):
     def _moe_quant_tensor_base(weight, config):
         base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]

@@ -250,13 +252,15 @@ def _moe_quant_tensor_fake_extra_dim_tensor(weight, config):
         # put tensors into modules since the handlers target modules not tensors
         dummy_modules = [DummyModule(tensor) for tensor in tensors]
         # apply handler to each module
-        quant_mods = list(map(lambda x: base_config_handler(x, config.base_config), dummy_modules))
+        quant_mods = list(
+            map(lambda x: base_config_handler(x, config.base_config), dummy_modules)
+        )
         # pack quantized subclasses into FakeExtraDimTensor
         quant_weight = FakeExtraDimTensor([mod.weight for mod in quant_mods])
         return quant_weight
 
     global _moe_quant_tensor_has_printed_error
-
+
     use_fake = config.use_fake_extra_dim_tensor
     if use_fake == UseFakeExtraDimTensor.FALSE:
         return _moe_quant_tensor_base(weight, config)

@@ -272,7 +276,6 @@ def _moe_quant_tensor_fake_extra_dim_tensor(weight, config):
         return _moe_quant_tensor_fake_extra_dim_tensor(weight, config)
 
 
-
 @register_quantize_module_handler(MoEQuantConfig)
 def moe_quant_fn(module, config: MoEQuantConfig):
     import warnings

@@ -283,7 +286,9 @@ def moe_quant_fn(module, config: MoEQuantConfig):
 
     for weight_attr in ["w1", "w2", "w3"]:
         param = getattr(module, weight_attr)
-        assert param.dim() == 3, f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}"
+        assert param.dim() == 3, (
+            f"when applying moe_quant to {module} expected 3D tensor for {weight_attr} but got {param.dim()}"
+        )
     assert isinstance(config.base_config, AOBaseConfig), (
         f"MoEQuantConfig expected to be initialized with an AOBaseConfig but got {type(config.base_config)}"
         + "this can happen if you initiaze with MoEQuantConfig(AOConfig) rather than MoEQuantConfig(AOConfig())"

0 commit comments
