Commit 1acc987

Add autograd function for mark_sharding (#8723)
1 parent a08b7f8 commit 1acc987

File tree

3 files changed: +48 −0

  test/spmd/test_xla_sharding.py
  torch_xla/distributed/spmd/__init__.py
  torch_xla/distributed/spmd/xla_sharding.py


test/spmd/test_xla_sharding.py

Lines changed: 18 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch_xla.debug.metrics as met
 import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
+from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction
 import torch_xla.distributed.parallel_loader as pl
 import test_xla_sharding_base

@@ -835,6 +836,23 @@ def test_mark_sharding_ir(self):

     self.assertTrue(torch.allclose(expected, actual.cpu()))

+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for autograd sharding test")
+  def test_mark_sharding_autograd(self):
+    x = torch.randn(8, 8, requires_grad=True)
+    x = x.to('xla')
+    mesh = self._get_mesh((1, self.n_devices))
+    # Forward pass
+    z = x @ x
+    z.retain_grad()  # To be able to extract HLO from intermediate tensor grad.
+    y = MarkShardingFunction.apply(z, mesh, (0, 1))
+    t = y.sum()
+    # Backward pass
+    t.backward()
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([z.grad])
+    sharding_annotation = 'sharding={devices=[1,%d]' % self.n_devices
+    self.assertIn(sharding_annotation, hlo)
+
   def test_sharded_tensor_aliasing(self):
     met.clear_all()
     partition_spec = (0, 1)

torch_xla/distributed/spmd/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
     "ShardingType",
     "ShardingSpec",
     "XLAPatchedLinear",
+    "MarkShardingFunction",
     "mark_sharding",
     "clear_sharding",
     "get_1d_mesh",

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 29 additions & 0 deletions
@@ -1123,3 +1123,32 @@ def _generate_logical_mesh(
       logical_mesh_shape)  # type: ignore  # numpy 2.2

   return logical_mesh
+
+
+class MarkShardingFunction(torch.autograd.Function):
+  """
+  Autograd function that applies mark_sharding to an intermediate tensor during
+  the forward pass and to that tensor's gradient during the backward pass.
+
+  Usage:
+  new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))
+
+  This helps guide GSPMD sharding propagation through the backward pass; in
+  complicated workloads the compiler can otherwise introduce extra collectives
+  that hurt performance.
+  """
+
+  @staticmethod
+  def forward(ctx, torch_tensor: torch.Tensor, mesh: Mesh,
+              partition_spec: Tuple) -> torch.Tensor:
+    mark_sharding(torch_tensor, mesh, partition_spec)
+    ctx.partition_spec = partition_spec
+    ctx.mesh = mesh
+    return torch_tensor
+
+  @staticmethod
+  def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+    partition_spec = ctx.partition_spec
+    mesh = ctx.mesh
+    mark_sharding(grad_output, mesh, partition_spec)
+    return grad_output, None, None
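
For context, here is a minimal usage sketch of the new function in a user-written forward/backward step, mirroring the pattern exercised by the test above. The mesh axis names ('data', 'model'), the 1 x N mesh shape, the tensor shapes, and the variable names are illustrative assumptions, not part of this commit; it also assumes SPMD mode is enabled and more than one device is visible.

# Hedged usage sketch; names and shapes below are illustrative assumptions.
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices), ('data', 'model'))

w = torch.randn(8, 8, requires_grad=True).to('xla')
x = torch.randn(8, 8).to('xla')

# Forward pass: annotate the intermediate activation so it is sharded along
# the 'model' mesh axis.
hidden = x @ w
hidden = MarkShardingFunction.apply(hidden, mesh, ('data', 'model'))

# Backward pass: MarkShardingFunction.backward applies the same annotation to
# hidden's gradient, so GSPMD does not have to infer it.
loss = hidden.sum()
loss.backward()

Compared with calling xs.mark_sharding on the activation directly, the autograd function also constrains the sharding of the gradient flowing back through that point, which is what the docstring means by reducing the chance of the compiler inserting extra collectives.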
