Commit 2ad5b5d

Support collective matmul optimization in mp (#8855)
Co-authored-by: Yifei Teng <yifeit@google.com>
1 parent 760675a commit 2ad5b5d

17 files changed: +282 −63 lines

setup.py
Lines changed: 4 additions & 4 deletions

@@ -66,10 +66,10 @@

 USE_NIGHTLY = True  # whether to use nightly or stable libtpu and jax

-_date = '20250303'
-_libtpu_version = '0.0.11'
-_jax_version = '0.5.2'
-_jaxlib_version = '0.5.2'
+_date = '20250320'
+_libtpu_version = '0.0.12'
+_jax_version = '0.5.4'
+_jaxlib_version = '0.5.4'

 _libtpu_wheel_name = f'libtpu-{_libtpu_version}'
 _libtpu_storage_directory = 'libtpu-lts-releases'

test/test_mp_all_gather.py
Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ def _mp_fn(index):
   # Testing with a single replica group
   ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
   result = xm.all_gather(ordinal_tensor, dim=0)
+  xm.mark_step()

   cpu_result = result.cpu()
   expected = torch.arange(0, world_size, dtype=torch.float)

test/test_mp_collective_matmul.py
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+import os
+import sys
+import torch
+import torch_xla
+from torch_xla import runtime as xr
+import torch_xla.core.xla_model as xm
+
+
+def _mp_fn(index):
+  os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1"
+  device = xm.xla_device()
+  world_size = xr.world_size()
+  groups = [[i for i in range(world_size)]]
+  scale = 1 / world_size
+  scatter_dim = 1
+  shard_size = 2
+
+  if xm.xla_device_hw(device) in ('TPU',):
+    # Testing with a single replica group, channel_id and use_global_device_ids
+    ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
+    result = xm.all_gather(
+        ordinal_tensor,
+        dim=0,
+        groups=groups,
+        channel_id=1,
+        use_global_device_ids=True)
+    xm.mark_step()
+
+    cpu_result = result.cpu()
+    expected = torch.arange(0, world_size, dtype=torch.float)
+    assert cpu_result.allclose(expected)
+
+    rand = torch.rand((32, shard_size * world_size, 32))
+    xrand = rand.to(device)
+
+    res = xm.reduce_scatter(
+        xm.REDUCE_SUM,
+        xrand,
+        scale,
+        scatter_dim,
+        world_size,
+        groups=groups,
+        channel_id=1,
+        use_global_device_ids=True)
+    expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand, scale)
+    xm.mark_step()
+
+    slice_idx = torch.tensor(
+        list(range(index * shard_size, (index + 1) * shard_size)))
+    expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
+
+    assert res.cpu().allclose(expected)
+
+
+if __name__ == '__main__':
+  torch_xla.launch(_mp_fn, args=())
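Note: the assertions above only check numerics. To also confirm that the channel handle reaches the lowered graph, a small debugging sketch can be dropped into the test; it is not part of this commit and assumes the existing torch_xla._XLAC._get_xla_tensors_hlo debug binding:

# Hypothetical check, placed right before the first xm.mark_step() while the
# all-gather for `result` is still pending: dump the HLO and look for the op.
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([result])
assert 'all-gather' in hlo_text  # with channel_id=1 the op should carry a channel handle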

test/tpu/run_tests.sh
Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ python3 "$TEST_CDIR/test_operations.py" -v
 python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"
 python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py"
 python3 "$TEST_CDIR/spmd/test_mp_input_sharding.py"
+python3 "$TEST_CDIR/test_mp_collective_matmul.py"
 run_save_tensor_hlo python3 "$TEST_CDIR/spmd/test_spmd_lowering_context.py"
 python3 "$TEST_CDIR/spmd/test_xla_sharding.py"
 python3 "$TEST_CDIR/spmd/test_xla_virtual_device.py"

torch_xla/_internal/tpu.py
Lines changed: 16 additions & 0 deletions

@@ -192,6 +192,22 @@ def version() -> int:
   return int(match.groups()[0])


+def get_tpu_type() -> str:
+  """
+  Return the tpu type. E.g. "v6e-8" returns "v6e"
+  """
+  try:
+    env = get_tpu_env()
+  except requests.HTTPError as e:
+    raise EnvironmentError('Failed to get TPU metadata') from e
+
+  match = re.search(r"^([^-]*)-", env[xenv.ACCELERATOR_TYPE])
+  if match:
+    return match.group(1)
+  else:
+    return env[xenv.ACCELERATOR_TYPE]
+
+
 def get_worker_ips() -> List[str]:
   """Returns ordered list of TPU worker IPs from TPU metadata."""
   if _using_env_vars():
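For reference, a short usage sketch of the new helper. It only works on a TPU host, since get_tpu_env() queries the TPU metadata server; elsewhere the HTTPError is re-raised as EnvironmentError:

from torch_xla._internal import tpu

# On a v6e-8 host, ACCELERATOR_TYPE is 'v6e-8' and this prints 'v6e'.
print(tpu.get_tpu_type())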

torch_xla/core/xla_model.py
Lines changed: 15 additions & 5 deletions

@@ -20,6 +20,7 @@
 import torch_xla.utils.utils as xu
 import torch_xla.utils.closures as xc
 from torch_xla.distributed.spmd.xla_sharding import ShardingSpec
+from torch_xla.distributed.xla_multiprocessing import create_optimized_replica_groups
 import os
 from torch_xla.experimental.deprecation import deprecated
 import torch_xla._internal.utils as _utils
@@ -532,7 +533,9 @@ def all_gather(value: torch.Tensor,
                dim: int = 0,
                groups: Optional[List[List[int]]] = None,
                output: Optional[torch.Tensor] = None,
-               pin_layout: bool = True) -> torch.Tensor:
+               pin_layout: bool = True,
+               channel_id=None,
+               use_global_device_ids=None) -> torch.Tensor:
   """Performs an all-gather operation along a given dimension.

   Args:
@@ -550,7 +553,8 @@ def all_gather(value: torch.Tensor,
       participate in the communication has slightly different program, but it might
       cause some xla compilation to fail. Unpin the layout when you see error message
       like "HloModule has a mix of layout constrained".
-
+    channel_id (int, optional): Optional channel ID for cross-module communication
+    use_global_device_ids(bool, optional): If true, interprets ids in ReplicaGroup as global device ids
   Returns:
     A tensor which has, in the ``dim`` dimension, all the values from the
     participating replicas.
@@ -584,7 +588,8 @@ def all_gather(value: torch.Tensor,
     return output

   result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
-                                           [], pin_layout)
+                                           [], pin_layout, channel_id,
+                                           use_global_device_ids)
   return result

   # Now the input should be a list of Tensors.
@@ -870,7 +875,9 @@ def reduce_scatter(reduce_type: str,
                    groups: Optional[List[List[int]]] = None,
                    output: Optional[Union[torch.Tensor,
                                           List[torch.Tensor]]] = None,
-                   pin_layout: bool = True) -> torch.Tensor:
+                   pin_layout: bool = True,
+                   channel_id=None,
+                   use_global_device_ids=None) -> torch.Tensor:
   """Performs a XLA `ReduceScatter()` operation on the input tensor.

   See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
@@ -896,6 +903,8 @@ def reduce_scatter(reduce_type: str,
       participate in the communication has slightly different program, but it might
       cause some xla compilation to fail. Unpin the layout when you see error message
       like "HloModule has a mix of layout constrained".
+    channel_id (int, optional): Optional channel ID for cross-module communication
+    use_global_device_ids(bool, optional): If true, interprets ids in ReplicaGroup as global device ids

   Returns:
     A `torch.Tensor` with all the values reduced across replicas. Each process
@@ -916,7 +925,8 @@ def reduce_scatter(reduce_type: str,
   result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
                                                scale, scatter_dim,
                                                shard_count, groups or [],
-                                               pin_layout)
+                                               pin_layout, channel_id,
+                                               use_global_device_ids)
   torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
   return result[0]
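At the Python level the change amounts to two new optional keyword arguments that are forwarded untouched to the C++ bindings. A minimal usage sketch, assuming a multi-process TPU run launched with torch_xla.launch and mirroring the new test above (channel_id=1 and the single replica group are illustrative values, not requirements):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
  world_size = xr.world_size()
  groups = [list(range(world_size))]
  t = torch.tensor([float(index)], device=device)
  # channel_id gives the collective a channel for cross-module communication;
  # use_global_device_ids makes the ids in the replica groups be read as
  # global device ids (see the docstring additions above).
  gathered = xm.all_gather(
      t, dim=0, groups=groups, channel_id=1, use_global_device_ids=True)
  scattered = xm.reduce_scatter(
      xm.REDUCE_SUM,
      gathered,
      scale=1.0 / world_size,
      scatter_dim=0,
      shard_count=world_size,
      groups=groups,
      channel_id=1,
      use_global_device_ids=True)
  xm.mark_step()


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())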

torch_xla/csrc/cross_replica_reduces.cpp
Lines changed: 29 additions & 8 deletions

@@ -232,10 +232,20 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
 AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
                                int64_t shard_count,
                                const std::vector<std::vector<int64_t>>& groups,
-                               bool pin_layout) {
+                               bool pin_layout,
+                               std::optional<int64_t> channel_id,
+                               std::optional<bool> use_global_device_ids) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
   TokenHandler token_handler(token);
+  std::optional<xla::ChannelHandle> channel_handle = std::nullopt;
+  if (channel_id.has_value()) {
+    xla::ChannelHandle channel_handle_value;
+    channel_handle_value.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
+    channel_handle_value.set_handle(channel_id.value());
+    channel_handle = channel_handle_value;
+  }
+
   xla::XlaOp all_gather_result;
   if (pin_layout) {
     torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
@@ -245,12 +255,13 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
         static_cast<XlaDeviceType>(xla_device.type()));
     all_gather_result =
         xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
-                       shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
-                       /*layout=*/reduce_shape.layout());
+                       shard_count, reduce_groups, channel_handle,
+                       /*layout=*/reduce_shape.layout(), use_global_device_ids);
   } else {
     all_gather_result =
         xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
-                       shard_count, reduce_groups);
+                       shard_count, reduce_groups, channel_handle,
+                       /*layout=*/std::nullopt, use_global_device_ids);
   }
   return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
 }
@@ -389,10 +400,19 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
 ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
     int64_t scatter_dim, int64_t shard_count,
-    const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout,
+    std::optional<int64_t> channel_id,
+    std::optional<bool> use_global_device_ids) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   TokenHandler token_handler(token);
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  std::optional<xla::ChannelHandle> channel_handle = std::nullopt;
+  if (channel_id.has_value()) {
+    xla::ChannelHandle channel_handle_value;
+    channel_handle_value.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
+    channel_handle_value.set_handle(channel_id.value());
+    channel_handle = channel_handle_value;
+  }
   xla::XlaOp reduce_result;
   if (pin_layout) {
     torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
@@ -403,13 +423,14 @@ ReduceScatterResult BuildReduceScatter(
     reduce_result = xla::ReduceScatter(
         token_handler.GetInput(input, &input_shape),
         GetReduceComutation(reduce_type, input_shape.element_type()),
-        scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
-        /*layout=*/reduce_shape.layout());
+        scatter_dim, shard_count, reduce_groups, channel_handle,
+        /*layout=*/reduce_shape.layout(), use_global_device_ids);
   } else {
     reduce_result = xla::ReduceScatter(
         token_handler.GetInput(input, &input_shape),
         GetReduceComutation(reduce_type, input_shape.element_type()),
-        scatter_dim, shard_count, reduce_groups);
+        scatter_dim, shard_count, reduce_groups, channel_handle,
+        /*layout=*/std::nullopt, use_global_device_ids);
   }

   if (scale != 1.0) {

torch_xla/csrc/cross_replica_reduces.h
Lines changed: 8 additions & 5 deletions

@@ -75,10 +75,11 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                              const std::vector<std::vector<int64_t>>& groups,
                              bool pin_layout);

-AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
-                               int64_t shard_count,
-                               const std::vector<std::vector<int64_t>>& groups,
-                               bool pin_layout);
+AllGatherResult BuildAllGather(
+    xla::XlaOp input, xla::XlaOp token, int64_t dim, int64_t shard_count,
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout,
+    std::optional<int64_t> channel_id = std::nullopt,
+    std::optional<bool> use_global_device_ids = std::nullopt);

 AllGatherResultCoalesced BuildAllGatherCoalesced(
     absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
@@ -98,7 +99,9 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
 ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
     int64_t scatter_dim, int64_t shard_count,
-    const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
+    const std::vector<std::vector<int64_t>>& groups, bool pin_layout,
+    std::optional<int64_t> channel_id = std::nullopt,
+    std::optional<bool> use_global_device_ids = std::nullopt);

 xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
                               double scale, int64_t scatter_dim,

torch_xla/csrc/init_python_bindings.cpp
Lines changed: 31 additions & 22 deletions

@@ -373,13 +373,16 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
     const std::string& reduce_type, const at::Tensor& input,
     const std::shared_ptr<torch::lazy::Value>& token, double scale,
     int64_t scatter_dim, int64_t shard_count,
-    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
+    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout,
+    std::optional<int64_t> channel_id = std::nullopt,
+    std::optional<bool> use_global_device_ids = std::nullopt) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr result;
   torch::lazy::Value new_token;
   std::tie(result, new_token) = tensor_methods::reduce_scatter(
       bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type), scale,
-      scatter_dim, shard_count, replica_groups, pin_layout);
+      scatter_dim, shard_count, replica_groups, pin_layout, channel_id,
+      use_global_device_ids);
   return std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>>(
       bridge::AtenFromXlaTensor(std::move(result)),
       std::make_shared<torch::lazy::Value>(new_token));
@@ -437,11 +440,13 @@ std::shared_ptr<torch::lazy::Value> ReduceScatterCoalescedOut(

 at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count,
                      const std::vector<std::vector<int64_t>>& replica_groups,
-                     bool pin_layout) {
+                     bool pin_layout,
+                     std::optional<int> channel_id = std::nullopt,
+                     std::optional<bool> use_global_device_ids = std::nullopt) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  auto result =
-      tensor_methods::all_gather(bridge::GetXlaTensor(input), dim, shard_count,
-                                 replica_groups, pin_layout);
+  auto result = tensor_methods::all_gather(
+      bridge::GetXlaTensor(input), dim, shard_count, replica_groups, pin_layout,
+      channel_id, use_global_device_ids);
   return bridge::AtenFromXlaTensor(std::move(result));
 }

@@ -1659,18 +1664,21 @@ void InitXlaModuleBindings(py::module m) {
           result_tuple[1] = new_token;
           return result_tuple;
         });
-  m.def("_xla_all_gather", [](const at::Tensor& input, int64_t dim,
-                              int64_t shard_count, const py::list& groups,
-                              bool pin_layout) {
-    std::vector<std::vector<int64_t>> replica_groups =
-        CreateReduceGroups(groups);
-    at::Tensor result;
-    {
-      NoGilSection nogil;
-      result = AllGather(input, dim, shard_count, replica_groups, pin_layout);
-    }
-    return result;
-  });
+  m.def("_xla_all_gather",
+        [](const at::Tensor& input, int64_t dim, int64_t shard_count,
+           const py::list& groups, bool pin_layout,
+           std::optional<int> channel_id = std::nullopt,
+           std::optional<bool> use_global_device_ids = std::nullopt) {
+          std::vector<std::vector<int64_t>> replica_groups =
+              CreateReduceGroups(groups);
+          at::Tensor result;
+          {
+            NoGilSection nogil;
+            result = AllGather(input, dim, shard_count, replica_groups,
+                               pin_layout, channel_id, use_global_device_ids);
+          }
+          return result;
+        });
   m.def("_xla_all_gather_out",
         [](at::Tensor& output, const at::Tensor& input,
            const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
@@ -1788,16 +1796,17 @@ void InitXlaModuleBindings(py::module m) {
         [](const std::string& reduce_type, const at::Tensor& input,
            const std::shared_ptr<torch::lazy::Value>& token, double scale,
            int64_t scatter_dim, int64_t shard_count, const py::list& groups,
-           bool pin_layout) {
+           bool pin_layout, std::optional<int64_t> channel_id = std::nullopt,
+           std::optional<bool> use_global_device_ids = std::nullopt) {
          std::vector<std::vector<int64_t>> replica_groups =
              CreateReduceGroups(groups);
          at::Tensor result;
          std::shared_ptr<torch::lazy::Value> new_token;
          {
            NoGilSection nogil;
-            std::tie(result, new_token) =
-                ReduceScatter(reduce_type, input, token, scale, scatter_dim,
-                              shard_count, replica_groups, pin_layout);
+            std::tie(result, new_token) = ReduceScatter(
+                reduce_type, input, token, scale, scatter_dim, shard_count,
+                replica_groups, pin_layout, channel_id, use_global_device_ids);
          }
          auto result_tuple = py::tuple(2);
          result_tuple[0] = torch::autograd::make_variable(
