|
22 | 22 | c10d_functional = torch.ops.c10d_functional
|
23 | 23 |
|
24 | 24 |
|
25 | | -NF4_OPS_TABLE: Dict[Any, Any] = {}
| 25 | +def nf4_all_gather_into_tensor(func, *args, **kwargs):
| 26 | +    assert len(args) > 1, "Expected valid input"
| 27 | +    assert len(args[0]) == 3, "Expected 3 input args"
| 28 | +    nf4tensor = args[0][0]
| 29 | +    group_size = args[0][1]
| 30 | +    name = args[0][2]
| 31 | +    updated_attrs = {}
| 32 | +    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
| 33 | +        updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name)
| 34 | +    updated_attrs.update(
| 35 | +        {
| 36 | +            "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])),
| 37 | +        }
| 38 | +    )
| 39 | +    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
| 40 | +    return updatedNF4Tensor
| 41 | +
| 42 | +
| 43 | +def scatter_nf4tensor(func, *args, **kwargs):
| 44 | +    assert len(args) > 1, "Expected valid input"
| 45 | +    assert len(args[0][0]) == 1, "Expected 1 output tensor"
| 46 | +    output_tensor = args[0][0][0]
| 47 | +    input_tensors = args[0][1]
| 48 | +    new_attr, update_work = [], []
| 49 | +    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
| 50 | +        input_attrs = []
| 51 | +        if input_tensors:
| 52 | +            for input_tensor in input_tensors[0]:
| 53 | +                assert input_tensor.size() == output_tensor.size(), (
| 54 | +                    "Input tensor size must match output tensor size, tensors are not evenly divided."
| 55 | +                )
| 56 | +                if hasattr(input_tensor, attr):
| 57 | +                    input_attrs.append(getattr(input_tensor, attr))
| 58 | +            input_attrs = [input_attrs]
| 59 | +        new_attr, update_work = func(
| 60 | +            [getattr(output_tensor, attr)], input_attrs, *args[0][2:]
| 61 | +        )
| 62 | +    # there are 3 works, return one of them, same as the tensor to fit the required output format
| 63 | +    return new_attr, update_work
| 64 | +
| 65 | +
| 66 | +NF4_OPS_TABLE: Dict[Any, Any] = {
| 67 | +    torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
| 68 | +    torch.ops.c10d.scatter_.default: scatter_nf4tensor,
| 69 | +}
26 | 70 |
|
27 | 71 |
|
28 | 72 | _INNER_TENSOR_NAMES_FOR_SHARDING = [
|
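The first hunk replaces the empty `NF4_OPS_TABLE` with handlers for the functional all-gather and scatter collectives: each handler applies the collective to every inner tensor named in `_INNER_TENSOR_NAMES_FOR_SHARDING`, then rebuilds the `NF4Tensor` via `construct_nf4_args` with updated metadata (the all-gather handler also scales the outer size by `group_size`). For context, the sketch below shows roughly how such a table is consumed from a subclass's `__torch_dispatch__`; the names (`OPS_TABLE`, `QuantizedTensor`) are hypothetical, and this is a minimal sketch rather than the actual torchao plumbing.

```python
# Illustrative table-driven op dispatch for a tensor subclass
# (hypothetical names; not the torchao NF4Tensor implementation).
from typing import Any, Callable, Dict

import torch

OPS_TABLE: Dict[Any, Callable] = {}


def implements(aten_ops):
    """Register a handler for one or more ATen / collective ops."""

    def decorator(fn):
        for op in aten_ops:
            OPS_TABLE[op] = fn
        return fn

    return decorator


class QuantizedTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in OPS_TABLE:
            # Handlers get the op plus the original args/kwargs, matching the
            # (func, *args, **kwargs) signatures used in the diff above.
            return OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"{func} is not supported for {cls.__name__}")
```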
@@ -233,7 +277,6 @@ def nf4_split(aten_op, args, kwargs=None):
|
233 | 277 | def nf4_new_zeros(aten_op, args, kwargs=None):
|
234 | 278 | nf4tensor = args[0]
|
235 | 279 | new_size = tuple(args[1])
|
236 | | -
237 | 280 | if nf4tensor.numel() % math.prod(new_size) != 0:
|
238 | 281 | raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
|
239 | 282 | ratio = nf4tensor.numel() // math.prod(new_size)
|
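The `nf4_new_zeros` hunk above only drops a blank line, but the check it surrounds is worth spelling out: `new_zeros` is rejected unless the requested element count divides the tensor's `numel()`, and the quotient `ratio` is consumed by the rest of the function, which falls outside this hunk. A worked example with purely illustrative numbers:

```python
# Worked example of the divisibility check and ratio in nf4_new_zeros
# (illustrative numbers only).
import math

numel = 8192        # hypothetical NF4Tensor.numel()
new_size = (2048,)  # size requested via aten.new_zeros

assert numel % math.prod(new_size) == 0  # 8192 % 2048 == 0, so no NotImplementedError
ratio = numel // math.prod(new_size)     # ratio == 4
```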
@@ -273,19 +316,37 @@ def nf4_slice(aten_op, args, kwargs=None):
|
273 | 316 | aten.view.default,
|
274 | 317 | ]
|
275 | 318 | )
|
276 | | -@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
| 319 | +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=")
277 | 320 | def nf4_view(aten_op, args, kwargs=None):
|
278 | 321 | nf4tensor = args[0]
|
279 | 322 | size = args[1]
|
280 | | -    if size[0] != -1:
281 | | -        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
282 | | -    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
283 | | -    updated_attrs.update(
284 | | -        {
285 | | -            "size": [nf4tensor.numel()],
286 | | -            "stride": (1,),
287 | | -        }
288 | | -    )
| 323 | +    if len(size) == 1:
| 324 | +        if size[0] != -1:
| 325 | +            raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
| 326 | +        else:
| 327 | +            updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
| 328 | +            updated_attrs.update(
| 329 | +                {
| 330 | +                    "size": [nf4tensor.numel()],
| 331 | +                    "stride": (1,),
| 332 | +                }
| 333 | +            )
| 334 | +    elif len(size) == 2:
| 335 | +        if nf4tensor.numel() != size[0] * size[1]:
| 336 | +            raise NotImplementedError("NF4Tensor size does not match view size.")
| 337 | +        updated_attrs = {}
| 338 | +        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
| 339 | +            attr_size = [getattr(nf4tensor, attr).size()]
| 340 | +            updated_attrs[attr] = aten_op(
| 341 | +                getattr(nf4tensor, attr), *attr_size, **kwargs
| 342 | +            )
| 343 | +        updated_attrs.update(
| 344 | +            {
| 345 | +                "stride": (size[1], 1),
| 346 | +            }
| 347 | +        )
| 348 | +    else:
| 349 | +        raise NotImplementedError("aten.view(NF4Tensor) with empty size")
289 | 350 | return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
|
290 | 351 |
|
291 | 352 |
|
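With the decorator relaxed from `len(size) == 1` to `len(size) < 3`, `nf4_view` now handles both the existing 1-D flatten (`size == (-1,)`) and a 2-D view whose element count matches the tensor. A hedged usage sketch, assuming torchao's `to_nf4` helper and a build that includes this change; the shapes are illustrative:

```python
# Usage sketch for the updated nf4_view handler (illustrative shapes; assumes
# torchao's to_nf4 helper and a build containing this change).
import torch
from torchao.dtypes import to_nf4

weight = to_nf4(torch.randn(1024, 512, dtype=torch.bfloat16))  # NF4Tensor of shape (1024, 512)

flat = weight.view(-1)         # len(size) == 1 and size[0] == -1: supported (flatten)
same = weight.view(1024, 512)  # len(size) == 2 and 1024 * 512 == weight.numel(): supported
# weight.view(512, 512)        # numel mismatch -> raises NotImplementedError
# weight.view(7)               # 1-D but not -1 -> raises NotImplementedError
```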
@@ -457,6 +518,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
|
457 | 518 | return tensors
|
458 | 519 |
|
459 | 520 |
|
| 521 | +@implements(
| 522 | +    [
| 523 | +        torch.ops._c10d_functional.wait_tensor.default,
| 524 | +    ]
| 525 | +)
| 526 | +def wait_tensor(func, *args, **kwargs):
| 527 | +    nf4tensor = args[0][0]
| 528 | +    updated_attrs = {}
| 529 | +    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
| 530 | +        updated_attrs[attr] = func(getattr(nf4tensor, attr))
| 531 | +    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
| 532 | +    return updatedNF4Tensor
| 533 | +
| 534 | +
460 | 535 | @dataclass(frozen=True)
|
461 | 536 | class SubclassTensorArgs:
|
462 | 537 | original_shape: torch.Size
|
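Finally, the `wait_tensor` handler registered above exists because the `_c10d_functional` collectives are asynchronous: the op returns immediately and the result is only valid after `wait_tensor` is called on it, so for an NF4Tensor the wait has to be forwarded to every inner tensor. Below is a minimal sketch of that pattern on a plain tensor, assuming an initialized process group; `group_size` and `group_name` come from the caller here, just as they arrive via `args` in the handlers above.

```python
# Minimal sketch of the functional all-gather + wait pattern that the NF4 handlers
# mirror per inner tensor (assumes torch.distributed is initialized).
import torch


def gather_shard(shard: torch.Tensor, group_size: int, group_name: str) -> torch.Tensor:
    # Issue the functional all-gather; the returned tensor is not yet ready.
    out = torch.ops._c10d_functional.all_gather_into_tensor(shard, group_size, group_name)
    # wait_tensor is the op the NF4 handler forwards to each inner tensor.
    return torch.ops._c10d_functional.wait_tensor(out)
```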
|