
Commit 6d88c08

Use shard_as in scan to ensure that inputs and their gradients have the same sharding (#8879)
1 parent fe3bb7f commit 6d88c08
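
The gist of the change, as a hedged sketch (the scan backward-pass edit itself is not reproduced on this page, and the variable names below are illustrative): instead of hard-coding a partition spec for the gradient of a scanned input, tie the gradient's sharding to the input's sharding with the new shard_as helper, so GSPMD propagates a matching sharding rather than a partially replicated one.

    import torch_xla.distributed.spmd as xs

    # Hypothetical backward step inside scan: whatever sharding `xs_input`
    # carries, its gradient is constrained to carry the same one (and vice
    # versa), without naming a concrete mesh or partition spec here.
    xs_input, grad_xs = xs.shard_as(xs_input, grad_xs)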

File tree

12 files changed: +281 −59 lines

.circleci/common.sh

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ function install_post_deps_pytorch_xla() {
     -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
     -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 
+  # TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
   pip install xla/torchax
 }
 

.github/workflows/_test.yml

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ jobs:
            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 
          # Install torchax
+         # TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
          pip install pytorch/xla/torchax
 
          if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then

.github/workflows/_tpu_ci.yml

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,10 @@ jobs:
          pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
          pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
          pip install --upgrade protobuf
+
+         # torchax is needed for call_jax tests.
+         # TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
+         pip install pytorch/xla/torchax
      - name: Run Tests
        env:
          PJRT_DEVICE: TPU

test/scan/test_scan.py

Lines changed: 3 additions & 1 deletion
@@ -475,7 +475,9 @@ def unpack(x):
 
     # Find the input that is stored in the context object.
     stored_xs = None
-    for s in storage:
+    # Dedupe the tensors because the autograd context may save the same tensor twice.
+    # Saving a tensor twice won't use extra storage though thanks to ref-counting.
+    for s in set(storage):
       if s.shape == xs.shape:
         assert stored_xs is None
         stored_xs = s
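
For context on the set(storage) change above, a small hedged sketch (names are illustrative): torch.Tensor hashes by object identity, so set() collapses repeated references to the same saved tensor without comparing values, and the duplicates cost no extra memory because they share one underlying storage.

    import torch

    t = torch.randn(4)
    storage = [t, t, torch.randn(4)]  # the same tensor saved twice, plus another
    unique = set(storage)             # dedupes by identity, not by value
    assert len(unique) == 2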

test/scan/test_scan_spmd.py

Lines changed: 95 additions & 3 deletions
@@ -1,15 +1,15 @@
-from copy import deepcopy
 import sys
 import re
 import unittest
 
+import numpy as np
 import torch
 import torch_xla
 import torch.nn as nn
-from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear
+from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh
 from torch_xla.experimental.scan import scan
 from torch_xla.experimental.scan_layers import scan_layers
-from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
+from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh
 import torch_xla.runtime as xr
 
 
@@ -59,6 +59,98 @@ def fn(carry, x):
         f'devices=[1,{N}]0,',
         torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
 
+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_2d_sharding(self):
+    """
+    Test the sharding propagation of gradients when scanning 2D sharded layers.
+
+    Specifically, we scan over a group of simple MLP blocks found in transformers.
+
+    Inputs:
+      A: [B_x, S, H_y]
+      W1: [F_y, H_x]
+      W2: [H_x, F_y]
+
+    Outputs:
+      B: [B_x, S, H_y]
+
+    B = A @ W1.T @ W2.T
+
+    Under 2D sharding, the gradient of loss w.r.t. (A @ W1.T) is 2D sharded.
+    A is also 2D sharded. GSPMD needs to figure out that the gradient of loss w.r.t.
+    W1 should also be 2D sharded.
+    """
+
+    mesh_shape = (2, xr.global_runtime_device_count() // 2)
+    mesh_axis_names = ("fsdp", "tensor")
+    mesh = Mesh(
+        np.arange(xr.global_runtime_device_count()), mesh_shape,
+        mesh_axis_names)
+
+    class MLPBlock(nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.up_proj = nn.Linear(128, 256, bias=False)
+        self.down_proj = nn.Linear(256, 128, bias=False)
+
+      def forward(self, hidden_states):
+        hidden_states = mark_sharding_with_gradients(hidden_states, mesh,
+                                                     ("fsdp", None, "tensor"))
+        hidden_states = self.up_proj(hidden_states)
+        hidden_states = mark_sharding_with_gradients(hidden_states, mesh,
+                                                     ("fsdp", None, "tensor"))
+        hidden_states = torch.sin(hidden_states)
+        hidden_states = mark_sharding_with_gradients(hidden_states, mesh,
+                                                     ("fsdp", None, "tensor"))
+        hidden_states = self.down_proj(hidden_states)
+        hidden_states = mark_sharding_with_gradients(hidden_states, mesh,
+                                                     ("fsdp", None, "tensor"))
+        return hidden_states
+
+    class MyModel(nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.layers = nn.Sequential(*[MLPBlock() for _ in range(4)])
+
+      def forward(self, hidden_states: torch.Tensor):
+        hidden_states = mark_sharding_with_gradients(hidden_states, mesh,
+                                                     ("fsdp", None, "tensor"))
+        return scan_layers(self.layers, hidden_states)
+
+    torch.manual_seed(42)
+    torch_xla.manual_seed(42)
+    model = MyModel().to('xla')
+    model = apply_xla_patch_to_nn_linear(model)
+    for name, param in model.named_parameters():
+      if 'up_proj' in name:
+        mark_sharding(param, mesh, ("tensor", "fsdp"))
+      if 'down_proj' in name:
+        mark_sharding(param, mesh, ("fsdp", "tensor"))
+
+    # Batch, Seq, Hidden
+    hidden_states = torch.randn((3, 50, 128), device='xla')
+    torch_xla.sync()
+
+    # Run the model
+    model.zero_grad()
+    out = model(hidden_states)
+    # Prepare to check the gradient of W1
+    for layer in model.layers.children():  # type: ignore
+      layer.up_proj.weight.retain_grad()  # type: ignore
+    out.sum().backward()
+    torch_xla.sync(wait=True)
+    # Check the gradient of W1
+    for layer in model.layers.children():  # type: ignore
+      # Right: {devices=[2,2]0,2,1,3}, {devices=[4,2]0,4,1,5,2,6,3,7} or similar
+      # Wrong: {devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} or similar
+      sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(
+          layer.up_proj.weight.grad)  # type: ignore
+      self.assertIn(f'devices=[{mesh_shape[1]},2]0', sharding_spec)
+      self.assertNotIn('last_tile_dim_replicate', sharding_spec)
+
   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required")
   def test_scan_xla_patched_linear(self):

test/spmd/test_xla_sharding.py

Lines changed: 23 additions & 0 deletions
@@ -1660,6 +1660,29 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
+  def test_shard_as(self):
+    mesh = self._get_mesh((self.n_devices,))
+    partition_spec = (0,)
+    x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
+                     dtype=torch.float,
+                     device=xm.xla_device())
+    x = xs.mark_sharding_with_gradients(x, mesh, partition_spec)
+    y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
+                     dtype=torch.float,
+                     device=xm.xla_device())
+
+    x, y = xs.shard_as(x, y)
+    torch_xla.sync()
+
+    sharding_spec = '{devices=[%d]' % self.n_devices
+    x_sharding = torch_xla._XLAC._get_xla_sharding_spec(x)
+    y_sharding = torch_xla._XLAC._get_xla_sharding_spec(y)
+    self.assertIn(sharding_spec, x_sharding)
+    self.assertEqual(x_sharding, y_sharding)
+
 
 if __name__ == '__main__':
   test = unittest.main()
torch_xla/_internal/jax_workarounds.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+import os
+from contextlib import contextmanager
+from typing import Callable, Any
+import functools
+
+
+# TODO(https://github.com/pytorch/xla/issues/8793): Get rid of this hack.
+def jax_import_guard():
+  import torch_xla
+  # Somehow, we need to grab the TPU before JAX locks it. Otherwise, any pt-xla TPU operations will hang.
+  torch_xla._XLAC._init_computation_client()
+
+
+# TODO(https://github.com/pytorch/xla/issues/8793): Get rid of this hack.
+def requires_jax(func: Callable) -> Callable:
+  """Decorator that ensures JAX is safely imported before function execution"""
+
+  @functools.wraps(func)
+  def wrapper(*args, **kwargs) -> Any:
+    try:
+      jax_import_guard()
+    except ImportError as e:
+      raise ImportError(
+          "JAX import guard fail due to PJRT client is unavailable.") from e
+    with jax_env_context():
+      return func(*args, **kwargs)
+
+  return wrapper
+
+
+# TODO(b/374631442): Get rid of this hack that works around MegaScale hanging.
+@contextmanager
+def jax_env_context():
+  previous_skip_megascale_env = None
+  try:
+    previous_skip_megascale_env = os.environ.get('SKIP_MEGASCALE_PJRT_CLIENT',
+                                                 None)
+    os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = 'true'
+    yield
+  finally:
+    if previous_skip_megascale_env:
+      os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = previous_skip_megascale_env
+    else:
+      os.environ.pop('SKIP_MEGASCALE_PJRT_CLIENT', None)
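
A minimal usage sketch of these helpers (the function name list_jax_devices is hypothetical): requires_jax first initializes the PJRT computation client, then runs the wrapped function under jax_env_context, so importing JAX inside the body neither grabs the TPU first nor re-discovers MegaScale devices.

    from torch_xla._internal.jax_workarounds import requires_jax

    @requires_jax
    def list_jax_devices():
      # Safe to import JAX here: the PJRT client is already initialized and
      # SKIP_MEGASCALE_PJRT_CLIENT is set for the duration of this call.
      import jax
      return jax.devices()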

torch_xla/core/xla_builder.py

Lines changed: 3 additions & 3 deletions
@@ -3,8 +3,8 @@
 from weakref import WeakKeyDictionary
 import torch
 import torch_xla
-from torch.utils._pytree import tree_flatten, tree_unflatten
-from torch_xla.experimental.custom_kernel import _jax_env_context, jax_import_guard
+from torch.utils._pytree import tree_flatten
+from torch_xla._internal.jax_workarounds import jax_env_context, jax_import_guard
 
 
 class Type:
@@ -862,7 +862,7 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None):
 
   # Prevent JAX from discovering MegaScale devices a second time. If we don't do this,
   # then the MegaScale device discovery will hang.
-  with _jax_env_context():
+  with jax_env_context():
     import jax
     import torchax.ops.mappings as mappings
 

torch_xla/distributed/spmd/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
     mark_sharding, mark_sharding_with_gradients, clear_sharding, get_1d_mesh,
     wrap_if_sharded, xla_patched_nn_linear_forward, set_global_mesh,
     get_global_mesh, _mark_manual_sharding, enable_manual_sharding,
-    disable_manual_sharding, apply_backward_optimization_barrier)
+    disable_manual_sharding, apply_backward_optimization_barrier, shard_as)
 from .api import xla_distribute_tensor, xla_distribute_module, auto_policy
 
 __all__ = [
@@ -19,6 +19,7 @@
     "MarkShardingFunction"
     "mark_sharding",
     "mark_sharding_with_gradients",
+    "shard_as",
     "clear_sharding",
     "get_1d_mesh",
     "wrap_if_sharded",

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 44 additions & 6 deletions
@@ -7,19 +7,23 @@
 from torch import Tensor
 from torch.library import custom_op
 import torch_xla
+import torch_xla.core.xla_builder as xb
 import torch_xla.core.xla_model as xm
 import torch_xla._internal.utils as _utils
 from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
 import torch_xla.runtime as xr
 import torch_xla.debug.profiler as xp
+from torch_xla._internal.jax_workarounds import requires_jax
 
 import numpy as np
 import functools
 import itertools
-from typing import Union, Sequence, Any, Optional
+from typing import TypeVar, Union, Any, Optional
+from collections.abc import Sequence
 from enum import IntEnum
 
 from torch.amp import custom_fwd, custom_bwd
+from torch.utils._pytree import tree_flatten, tree_unflatten
 
 PartitionSpec = tuple[Union[tuple[Union[int, str], ...], int, str, None], ...]
 """PartitionSpec describes the sharding of a tensor.
@@ -574,7 +578,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
   # We only allow fully specified `partition_spec` to be applicable, as opposed
-  # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
+  # to filling in the unspecified replicated dims. Fully specified `partition_spec`
   # should be of the same rank as `t`. This is to support partial replication
   # where the group assignment may vary with different input ranks.
   assert len(t.shape) == len(partition_spec), \
@@ -588,8 +592,7 @@
 
 def mark_sharding_with_gradients(
     t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
-    partition_spec: tuple[Union[tuple, int, str, None],
-                          ...]) -> XLAShardedTensor:
+    partition_spec: tuple[Union[tuple, int, str, None], ...]) -> torch.Tensor:
   """
   A function to add sharding annotations on intermediate tensors (not in-place) and the gradient
   of the intermediate tensors during backward pass.
@@ -618,13 +621,48 @@
   This version can also be used in AOTAutograd.
   """
   # We only allow fully specified `partition_spec` to be applicable, as opposed
-  # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
+  # to filling in the unspecified replicated dims. Fully specified `partition_spec`
   # should be of the same rank as `t`. This is to support partial replication
   # where the group assignment may vary with different input ranks.
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
-  return MarkShardingFunction.apply(t, mesh, partition_spec)
+  r = MarkShardingFunction.apply(t, mesh, partition_spec)
+  assert isinstance(r, torch.Tensor)
+  return r
+
+
+PyTreeA = TypeVar('PyTreeA')
+PyTreeB = TypeVar('PyTreeB')
+
+
+@requires_jax
+def shard_as(a: PyTreeA, b: PyTreeB) -> tuple[PyTreeA, PyTreeB]:
+  """Ensure that `a` and `b` are sharded the same way without specifying
+  a particular sharding constraint.
+
+  shard_as takes two PyTrees of matching structure and returns
+  two PyTrees of that same structure. As long as you use at least one
+  of the outputs, then corresponding tensors in all four PyTrees
+  (a, b, out[0], out[1]) will be sharded the same way.
+  """
+
+  a_flat, a_spec = tree_flatten(a)
+  b_flat, b_spec = tree_flatten(b)
+  assert a_spec == b_spec, f"a and b must have the same structure. got {a_spec} and {b_spec}"
+  a_sharded_flat = []
+  b_sharded_flat = []
+  from jax.experimental.shard_alike import shard_alike
+  for x, y in zip(a_flat, b_flat):
+    if x is None or y is None:
+      # If there are None leaves, then it should be None in both PyTrees.
+      assert x is None and y is None
+    else:
+      x, y = xb.call_jax(shard_alike, (x, y))
+    a_sharded_flat.append(x)
+    b_sharded_flat.append(y)
+  return tree_unflatten(a_sharded_flat,
+                        a_spec), tree_unflatten(b_sharded_flat, b_spec)
 
 
 def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
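
A hedged usage sketch of the new shard_as (tensor names and shapes are illustrative, and JAX must be installed since the helper is decorated with requires_jax): only x gets an explicit annotation, and y ends up with the same sharding because both pass through shard_as and the outputs are used.

    import torch
    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()
    mesh = xs.get_1d_mesh()  # one mesh axis spanning all devices

    x = torch.randn(8, 128, device='xla')
    y = torch.randn(8, 128, device='xla')
    xs.mark_sharding(x, mesh, (0, None))  # shard x along dim 0

    # x and y (and the returned aliases) now share the same sharding.
    x, y = xs.shard_as(x, y)
    torch_xla.sync()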
