
Commit 050b293

Add support for resharding for fbgemm configs
Summary: Added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpointing for resharding.

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
1 parent 6243040 commit 050b293
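For context on how an op like torch.cat reaches a quantized weight at all: torchao's quantized weights are tensor subclasses, and ops are routed to their quantized components through dispatch. The sketch below shows the general wrapper-subclass pattern this commit builds on. It is a toy illustration, not the code from this commit: PerRowFp8Tensor, its per-row scale layout, and the dim == 0 restriction are all assumptions made for the example.

import torch

aten = torch.ops.aten

class PerRowFp8Tensor(torch.Tensor):
    """Toy stand-in for an fp8 weight tensor: float8 payload + per-row scale.

    NOT torchao's implementation; only a sketch of routing torch.cat to the
    quantized components via __torch_dispatch__.
    """

    @staticmethod
    def __new__(cls, float8_data, scale):
        # The wrapper advertises the logical (high-precision) shape/dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls, float8_data.shape, dtype=torch.bfloat16, device=float8_data.device
        )

    def __init__(self, float8_data, scale):
        self.float8_data = float8_data
        self.scale = scale

    def __repr__(self):
        return f"PerRowFp8Tensor(shape={tuple(self.shape)})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is aten.cat.default:
            tensors = args[0]
            dim = args[1] if len(args) > 1 else kwargs.get("dim", 0)
            assert dim == 0, "sketch only handles row-wise (dim=0) cat"
            # Row-wise cat: concatenate payloads and stack per-row scales.
            return PerRowFp8Tensor(
                torch.cat([t.float8_data for t in tensors], dim=0),
                torch.cat([t.scale for t in tensors], dim=0),
            )
        raise NotImplementedError(f"{func} not handled in this sketch")

# Two (256, 128) shards cat into a (512, 128) weight without dequantizing.
a = PerRowFp8Tensor(torch.randn(256, 128).to(torch.float8_e4m3fn), torch.ones(256))
b = PerRowFp8Tensor(torch.randn(256, 128).to(torch.float8_e4m3fn), torch.ones(256))
print(torch.cat([a, b], dim=0))  # PerRowFp8Tensor(shape=(512, 128))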

File tree

7 files changed: +467 −54 lines

test/dtypes/test_fbgemm_fp8.py

Lines changed: 47 additions & 0 deletions

@@ -146,6 +146,53 @@ def test_to_device(self):
         quantize_(linear, self.config)
         linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat with dim == 1 is not really correct and will be fixed later
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat([linear1.weight.float8_data, linear2.weight.float8_data], dim=1)
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+
+    def test_transpose(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        # bias must be resized to match the transposed out_features
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertEqual(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 128))
+
 
 if __name__ == "__main__":
     run_tests()
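Why the dim == 1 cat above is flagged as "not really correct": with per-row scales there is no single scale that is right for the stitched rows, because each shard quantized its rows with its own scale, and the op keeps only linear1's. A small plain-torch illustration of the resulting error (the 448 max and the shapes are assumptions for the example, not values from this test):

import torch

def rowwise_quant(w):
    # One scale per row, sized so the row's max maps to the e4m3 max (~448).
    scale = w.abs().amax(dim=1, keepdim=True) / 448.0
    return w / scale, scale

w1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
w2 = w1 * 10  # second shard is 10x larger, so its row scales are 10x larger

q1, s1 = rowwise_quant(w1)
q2, s2 = rowwise_quant(w2)

# cat along dim=1 keeps only the first shard's scale, as the test does:
recon = torch.cat([q1, q2], dim=1) * s1
print(recon[:, :2])  # matches w1
print(recon[:, 2:])  # != w2: rescaled by s1 instead of s2, off by 10x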

test/dtypes/test_fbgemm_int4.py

Lines changed: 47 additions & 0 deletions

@@ -152,6 +152,53 @@ def test_to_device(self):
         quantize_(linear, self.config)
         linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+        dummy2 = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        dummy2.weight = torch.nn.Parameter(cat_weight2)
+        quantize_(dummy1, self.config)
+        quantize_(dummy2, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.packed_weight, cat_qweight1.packed_weight)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+        self.assertEqual(dummy1.weight.zero_point, cat_qweight1.zero_point)
+
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        self.assertEqual(dummy2.weight.packed_weight, cat_qweight2.packed_weight)
+        self.assertEqual(dummy2.weight.scale, cat_qweight2.scale)
+        self.assertEqual(dummy2.weight.zero_point, cat_qweight2.zero_point)
+
+    def test_transpose(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        # transpose again to return to the original state
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        self.assertEqual(linear1.weight.shape, (256, 128))
+
+        input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 256))
+
 
 if __name__ == "__main__":
     run_tests()
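Unlike the fp8 test, the int4 test can compare the dim == 1 cat against a directly quantized reference (dummy2): with group-wise quantization along the input dimension, every (row, group) block keeps its own scale and zero_point, so concatenating columns just concatenates the per-group parameters. A shape-level sketch, assuming a group size of 128 (the actual group size comes from self.config and is not shown in this diff):

import torch

rows, cols, group = 256, 128, 128
n_groups = cols // group  # one group per row at this size

# Per-(row, group) quantization parameters for each shard.
scale1, zp1 = torch.rand(rows, n_groups), torch.zeros(rows, n_groups)
scale2, zp2 = torch.rand(rows, n_groups), torch.zeros(rows, n_groups)

# cat along dim=1 doubles the groups per row; every block keeps the scale
# and zero_point it was quantized with, so no value is rescaled incorrectly.
cat_scale = torch.cat([scale1, scale2], dim=1)
cat_zp = torch.cat([zp1, zp2], dim=1)
print(cat_scale.shape, cat_zp.shape)  # torch.Size([256, 2]) twice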

torchao/core/config.py

Lines changed: 36 additions & 0 deletions

@@ -5,13 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 import abc
 import dataclasses
+from dataclasses import dataclass
 import enum
 import importlib
 import json
 from typing import Any, ClassVar, Dict
 
 import torch
 
+__all__ = [
+    "AOBaseConfig",
+    "VersionMismatchError",
+    "config_to_dict",
+    "config_from_dict",
+    "ALLOWED_AO_MODULES",
+    "e4m3_dtype",
+    "e5m2_dtype",
+]
+
 
 class AOBaseConfig(abc.ABC):
     """
@@ -284,3 +295,28 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
             return cls(**processed_data)
         except Exception as e:
             raise ValueError(f"Failed to create instance of {cls.__name__}: {e}")
+
+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm supports the fnuz variants on MI300 and the OCP FP8 variants on MI350/Navi4.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            self.e4m3_dtype = torch.float8_e4m3fnuz
+            self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
+# Resolved F8 dtypes for downstream use, based on the config above
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
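One thing to note about the hunk above: is_MI300() is not imported in either hunk shown, so it presumably comes from elsewhere in the module or needs an import such as `from torchao.utils import is_MI300` (an assumption; torchao ships an is_MI300 helper in torchao.utils). Downstream code is expected to consume the resolved aliases rather than branching on platform itself; a minimal sketch, assuming the module imports cleanly:

import torch
from torchao.core.config import e4m3_dtype, e5m2_dtype

# On CUDA (and on ROCm MI350/Navi4) these resolve to the OCP types,
# float8_e4m3fn / float8_e5m2; on ROCm MI300 they resolve to the fnuz
# variants. Callers never branch on the platform themselves.
x = torch.randn(8, 8)
print(x.to(e4m3_dtype).dtype)  # e.g. torch.float8_e4m3fn on CUDA
print(x.to(e5m2_dtype).dtype)  # e.g. torch.float8_e5m2 on CUDA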
