Commit c2a6568

[NF4] Support nf4 tensor shard and gather (#2449)
* support nf4 tensor shard and gather
1 parent bc30c2a commit c2a6568

2 files changed: +134 −18 lines


test/dtypes/test_nf4.py

Lines changed: 47 additions & 6 deletions
@@ -435,19 +435,24 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
             inner_tensor = getattr(viewed_tensor, attr)
             self.assertEqual(inner_tensor.size(0), inner_tensor.numel())
 
-    @parametrize("input_size", [(512 * 512,), (512, 512)])
+    @parametrize("input_size", [(512, 512)])
+    def test_tensor_2d_view_valid(self, input_size: Tuple[int]):
+        nf4_tensor = to_nf4(torch.randn(input_size))
+        viewed_tensor = nf4_tensor.view(input_size)
+        self.assertEqual(viewed_tensor.dim(), 2)
+        self.assertEqual(viewed_tensor.numel(), math.prod(input_size))
+        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+            inner_tensor = getattr(viewed_tensor, attr)
+            self.assertEqual(inner_tensor.size(0), inner_tensor.numel())
+
+    @parametrize("input_size", [(512 * 512,)])
     def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
         nf4_tensor = to_nf4(torch.randn(input_size))
         if len(input_size) == 1:
             with self.assertRaisesRegex(
                 NotImplementedError, "aten.view\\(NF4Tensor\\) with size"
             ):
                 nf4_tensor.view(input_size)
-        if len(input_size) == 2:
-            with self.assertRaisesRegex(
-                NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"
-            ):
-                nf4_tensor.view(input_size)
 
     @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
     def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
@@ -741,6 +746,42 @@ def _test_qlora_fsdp2(
         self.assertEqual(fsdp_loss, base_loss)
 
 
+class TestComm(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    def test_comm(self):
+        self.run_subtests(
+            {"input_size": [512, 2048]},
+            self._test_comm,
+        )
+
+    def _test_comm(self, input_size: int):
+        from torch.distributed._composable.fsdp import fully_shard
+        from torch.distributed._tensor import distribute_tensor
+
+        model = nn.Linear(input_size, input_size, device="cuda")
+        origin_tensor = model.weight
+        origin_nf4_tensor = to_nf4(origin_tensor)
+        model = fully_shard(model)
+        sharded_tensor = model.weight
+        sharded_origin_nf4_tensor = distribute_tensor(
+            origin_nf4_tensor,
+            sharded_tensor.device_mesh,
+            sharded_tensor.placements,
+        )
+
+        sharded_nf4_detach = sharded_origin_nf4_tensor.detach()
+        resumed_full_tensor = sharded_nf4_detach.full_tensor()
+
+        self.assertEqual(
+            origin_nf4_tensor.get_original_weight(),
+            resumed_full_tensor.get_original_weight(),
+        )
+
+
 instantiate_parametrized_tests(TestNF4Linear)
 instantiate_parametrized_tests(TestFSDPOps)
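For reference, the round trip exercised by TestComm above looks roughly like this outside the test harness. This is a minimal sketch, assuming two CUDA ranks and an initialized FSDP2 process group; it only reuses names that appear in the test (to_nf4, fully_shard, distribute_tensor, full_tensor, get_original_weight) and is not an additional API introduced by this commit.

# Sketch only: mirrors TestComm._test_comm under the assumptions above.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import distribute_tensor
from torchao.dtypes.nf4tensor import to_nf4

model = nn.Linear(512, 512, device="cuda")
nf4_weight = to_nf4(model.weight)              # full (unsharded) NF4 copy of the weight
model = fully_shard(model)                     # weight becomes a sharded DTensor
sharded_nf4 = distribute_tensor(               # shard the NF4 tensor the same way
    nf4_weight, model.weight.device_mesh, model.weight.placements
)
gathered = sharded_nf4.detach().full_tensor()  # all-gather back into a full NF4Tensor
torch.testing.assert_close(
    nf4_weight.get_original_weight(), gathered.get_original_weight()
)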

torchao/dtypes/nf4tensor.py

Lines changed: 87 additions & 12 deletions
@@ -22,7 +22,51 @@
 c10d_functional = torch.ops.c10d_functional
 
 
-NF4_OPS_TABLE: Dict[Any, Any] = {}
+def nf4_all_gather_into_tensor(func, *args, **kwargs):
+    assert len(args) > 1, "Expected valid input"
+    assert len(args[0]) == 3, "Expected 3 input args"
+    nf4tensor = args[0][0]
+    group_size = args[0][1]
+    name = args[0][2]
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name)
+    updated_attrs.update(
+        {
+            "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])),
+        }
+    )
+    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    return updatedNF4Tensor
+
+
+def scatter_nf4tensor(func, *args, **kwargs):
+    assert len(args) > 1, "Expected valid input"
+    assert len(args[0][0]) == 1, "Expected 1 output tensor"
+    output_tensor = args[0][0][0]
+    input_tensors = args[0][1]
+    new_attr, update_work = [], []
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        input_attrs = []
+        if input_tensors:
+            for input_tensor in input_tensors[0]:
+                assert input_tensor.size() == output_tensor.size(), (
+                    "Input tensor size must match output tensor size, tensors are not evenly divided."
+                )
+                if hasattr(input_tensor, attr):
+                    input_attrs.append(getattr(input_tensor, attr))
+            input_attrs = [input_attrs]
+        new_attr, update_work = func(
+            [getattr(output_tensor, attr)], input_attrs, *args[0][2:]
+        )
+    # there are 3 works, return one of them, same as the tensor to fit the required output format
+    return new_attr, update_work
+
+
+NF4_OPS_TABLE: Dict[Any, Any] = {
+    torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
+    torch.ops.c10d.scatter_.default: scatter_nf4tensor,
+}
 
 
 _INNER_TENSOR_NAMES_FOR_SHARDING = [
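In rough terms, both new handlers fan the collective out over each inner tensor listed in _INNER_TENSOR_NAMES_FOR_SHARDING and then rebuild the NF4Tensor metadata; the only shape bookkeeping is scaling the first dimension of the reported size by the gather group size. A small worked example of that arithmetic (hypothetical shard shapes, not taken from the patch):

# Hypothetical: a (512, 512) weight sharded row-wise across a 2-rank mesh.
local_rows, cols = 256, 512                      # each rank's NF4 shard is (256, 512)
group_size = 2                                   # second positional arg to all_gather_into_tensor
gathered_size = (local_rows * group_size, cols)  # (512, 512), matching the handler's torch.Size(...)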
@@ -233,7 +277,6 @@ def nf4_split(aten_op, args, kwargs=None):
 def nf4_new_zeros(aten_op, args, kwargs=None):
     nf4tensor = args[0]
     new_size = tuple(args[1])
-
     if nf4tensor.numel() % math.prod(new_size) != 0:
         raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
     ratio = nf4tensor.numel() // math.prod(new_size)
@@ -273,19 +316,37 @@ def nf4_slice(aten_op, args, kwargs=None):
         aten.view.default,
     ]
 )
-@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
+@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=")
 def nf4_view(aten_op, args, kwargs=None):
     nf4tensor = args[0]
     size = args[1]
-    if size[0] != -1:
-        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
-    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
-    updated_attrs.update(
-        {
-            "size": [nf4tensor.numel()],
-            "stride": (1,),
-        }
-    )
+    if len(size) == 1:
+        if size[0] != -1:
+            raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
+        else:
+            updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
+            updated_attrs.update(
+                {
+                    "size": [nf4tensor.numel()],
+                    "stride": (1,),
+                }
+            )
+    elif len(size) == 2:
+        if nf4tensor.numel() != size[0] * size[1]:
+            raise NotImplementedError("NF4Tensor size does not match view size.")
+        updated_attrs = {}
+        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+            attr_size = [getattr(nf4tensor, attr).size()]
+            updated_attrs[attr] = aten_op(
+                getattr(nf4tensor, attr), *attr_size, **kwargs
+            )
+        updated_attrs.update(
+            {
+                "stride": (size[1], 1),
+            }
+        )
+    else:
+        raise NotImplementedError("aten.view(NF4Tensor) with empty size")
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
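Put plainly, aten.view on an NF4Tensor now accepts a flatten-to-1-D view or a 2-D view whose element count matches; anything else still raises NotImplementedError, either from the branches above or from the relaxed length check in the decorator. A short sketch of the behaviour implied by the updated nf4_view and the new test, assuming a (512, 512) NF4Tensor:

# Behaviour inferred from the hunk above; to_nf4 and torch as in the tests.
nf4 = to_nf4(torch.randn(512, 512))
flat = nf4.view(-1)           # 1-D view: the size must be [-1], as before
same = nf4.view(512, 512)     # 2-D view with matching numel, newly supported
# nf4.view(512, 256)          # numel mismatch -> NotImplementedError
# nf4.view(2, 256, 512)       # len(size) >= 3 is rejected by expect_args_len_at_k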

@@ -457,6 +518,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
     return tensors
 
 
+@implements(
+    [
+        torch.ops._c10d_functional.wait_tensor.default,
+    ]
+)
+def wait_tensor(func, *args, **kwargs):
+    nf4tensor = args[0][0]
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        updated_attrs[attr] = func(getattr(nf4tensor, attr))
+    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    return updatedNF4Tensor
+
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
