
Commit 483c9e1

Add new API local_map (#71804)

* Add new API local_map
* Fix documentation formatting
* Improve some local_map functionality
* Add reshard support and make local_map usable in both dynamic and static graph modes
* Fix formatting conventions
* Fix interface naming in the unit test
* Replace LocalLayer with local_map
* Adjust the unit test's local_map arguments, setting reshard to True
* Fix the unit test
* Update the unit test
* Fix formatting
* Fix formatting
* Fix the test example formatting
* Fix the test example formatting
1 parent 9b7cd3d commit 483c9e1

File tree

9 files changed, +653 -144 lines changed


python/paddle/distributed/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,7 @@
     SequenceParallelEnd,
 )
 from .auto_parallel.local_layer import LocalLayer
+from .auto_parallel.local_map import local_map
 from .auto_parallel.placement_type import (
     Partial,
     Replicate,
@@ -192,6 +193,7 @@
     "Strategy",
     "DistModel",
     "LocalLayer",
+    "local_map",
     "unshard_dtensor",
     "parallelize",
     "SequenceParallelEnd",

python/paddle/distributed/auto_parallel/api.py

Lines changed: 23 additions & 0 deletions
@@ -1061,6 +1061,29 @@ def replicate_layer_params_and_buffers(
     )
 
 
+def is_dist_tensor(tensor) -> bool:
+    """
+    Check if an input is a dist_tensor in both dynamic and static modes.
+
+    Args:
+        tensor: The input to check
+
+    Returns:
+        bool: True if the input is a dist_tensor, False otherwise
+    """
+    if paddle.in_dynamic_mode():
+        return (
+            isinstance(tensor, paddle.Tensor)
+            and hasattr(tensor, 'is_dist')
+            and tensor.is_dist()
+        )
+    else:
+        return (
+            isinstance(tensor, paddle.base.libpaddle.pir.Value)
+            and tensor.dist_attr() is not None
+        )
+
+
 class _ShardOptimizer(Optimizer):
     def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         assert (
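
A minimal illustration of the new helper in dynamic mode (a sketch assuming a multi-card launch; the mesh and tensor names here are illustrative, not part of the commit):

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.api import is_dist_tensor

dense = paddle.ones([4, 4])
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
sharded = dist.shard_tensor(dense, mesh, [dist.Shard(0)])

print(is_dist_tensor(dense))    # False: a plain paddle.Tensor
print(is_dist_tensor(sharded))  # True: the tensor carries distributed attributes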

python/paddle/distributed/auto_parallel/local_map.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import functools
17+
from typing import TYPE_CHECKING, Any, Callable
18+
19+
import paddle
20+
import paddle.distributed as dist
21+
from paddle.utils import flatten, pack_sequence_as
22+
23+
if TYPE_CHECKING:
24+
from paddle.distributed import ProcessMesh
25+
26+
27+
def local_map(
28+
func: Callable[..., Any],
29+
out_placements: list[list[dist.Placement]],
30+
in_placements: list[list[dist.Placement]] | None = None,
31+
process_mesh: ProcessMesh | None = None,
32+
reshard_inputs: bool = False,
33+
) -> Callable[..., Any]:
34+
"""
35+
The `local_map` API allows users to pass dist_tensors to a function that is written
36+
to be applied on ``paddle.Tensor`` s. It works by extracting the local components
37+
of dist_tensors, calling the function, and wrapping the outputs as dist_tensors
38+
according to the ``out_placements``.
39+
40+
Args:
41+
func (Callable): The function to be applied on each local shard of dist_tensors.
42+
43+
out_placements (list[list[dist.Placement]]):
44+
The desired placements for each output tensor. Must be a list where each element
45+
is a list of Placement objects specifying the distribution strategy for that
46+
output tensor. The length of the outer list must match the number of outputs
47+
from ``func``. For non-tensor outputs, the corresponding placement must be None.
48+
When there are no dist_tensor inputs, process_mesh must be specified to use
49+
non-None placements.
50+
51+
in_placements (Optional[list[list[dist.Placement]]], optional):
52+
The required placements for each input tensor. If specified, must be a list
53+
where each element is a list of Placement objects defining the distribution
54+
strategy for that input tensor. The length of the outer list must match the
55+
number of input tensors.
56+
Default: None
57+
58+
process_mesh (ProcessMesh, optional):
59+
The process mesh that all dist_tensors are placed on. If not specified,
60+
this will be inferred from the input dist_tensors' process mesh.
61+
local_map requires all dist_tensors to be placed on the same process mesh.
62+
Must be specified when there are no dist_tensor inputs but out_placements
63+
contains non-None values.
64+
Default: None
65+
66+
reshard_inputs (bool, optional):
67+
the bool value indicating whether to reshard the input :dist_tensors when
68+
their placements are different from the required input placements. If this
69+
value is ``False`` and some :dist_tensor input has a different placement,
70+
an exception will be raised. Default: False.
71+
72+
Returns:
73+
Callable: A function that applies func to local shards of input dist_tensors and returns dist_tensors or original values.
74+
75+
Example:
76+
.. code-block:: python
77+
78+
>>> from __future__ import annotations
79+
>>> import paddle
80+
>>> import paddle.distributed as dist
81+
>>> from paddle import Tensor
82+
>>> from paddle.distributed import ProcessMesh
83+
84+
>>> def custom_function(x):
85+
... mask = paddle.zeros_like(x)
86+
... if dist.get_rank() == 0:
87+
... mask[1:3] = 1
88+
... else:
89+
... mask[4:7] = 1
90+
... x = x * mask
91+
... mask_sum = paddle.sum(x)
92+
... mask_sum = mask_sum / mask.sum()
93+
... return mask_sum
94+
95+
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
96+
>>> dist.init_parallel_env()
97+
>>> mesh = ProcessMesh([0, 1], dim_names=["x"])
98+
>>> local_input = paddle.arange(0, 10, dtype="float32")
99+
>>> local_input = local_input + dist.get_rank()
100+
>>> input_dist = dist.auto_parallel.api.dtensor_from_local(
101+
... local_input, mesh, [dist.Shard(0)]
102+
... )
103+
>>> wrapped_func = dist.local_map(
104+
... custom_function,
105+
... out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
106+
... in_placements=[[dist.Shard(0)]],
107+
... process_mesh=mesh
108+
... )
109+
>>> output_dist = wrapped_func(input_dist)
110+
111+
>>> local_value = output_dist._local_value()
112+
>>> gathered_values: list[Tensor] = []
113+
>>> dist.all_gather(gathered_values, local_value)
114+
115+
>>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
116+
[Rank 0] local_value=1.5
117+
>>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
118+
[Rank 1] local_value=6.0
119+
>>> print(f"global_value (distributed)={output_dist.item()}")
120+
global_value (distributed)=7.5
121+
122+
>>> # This case needs to be executed in a multi-card environment
123+
>>> # export CUDA_VISIBLE_DEVICES=0,1
124+
>>> # python -m paddle.distributed.launch {test_case}.py
125+
"""
126+
127+
def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs):
128+
# Process input arguments
129+
flat_dist_args = flatten(args)
130+
if in_placements is not None:
131+
assert len(in_placements) == len(flat_dist_args), (
132+
f"in_placements length {len(in_placements)} does not match "
133+
f"number of input args {len(flat_dist_args)}!"
134+
)
135+
136+
flat_local_args = []
137+
seen_dist_tensor = False
138+
139+
for idx, arg in enumerate(flat_dist_args):
140+
if dist.auto_parallel.api.is_dist_tensor(arg):
141+
dist_tensor = arg
142+
if process_mesh is None:
143+
if paddle.in_dynamic_mode():
144+
process_mesh = dist_tensor.process_mesh
145+
else:
146+
process_mesh = dist_tensor.dist_attr().process_mesh
147+
148+
seen_dist_tensor = True
149+
150+
if in_placements is not None:
151+
in_placement = in_placements[idx]
152+
if in_placement is None:
153+
if paddle.in_dynamic_mode():
154+
in_placement = dist_tensor.placements
155+
else:
156+
in_placement = dist_tensor.dist_attr().placements
157+
else:
158+
if paddle.in_dynamic_mode():
159+
if in_placement != dist_tensor.placements:
160+
if reshard_inputs:
161+
dist_tensor = dist.reshard(
162+
dist_tensor, process_mesh, in_placement
163+
)
164+
else:
165+
raise ValueError(
166+
f"in_placement {in_placement} does not match dist_tensor.placements {dist_tensor.placements}"
167+
)
168+
169+
else:
170+
if (
171+
in_placement
172+
!= dist_tensor.dist_attr().placements
173+
):
174+
if reshard_inputs:
175+
dist_tensor = dist.reshard(
176+
dist_tensor, process_mesh, in_placement
177+
)
178+
else:
179+
raise ValueError(
180+
f"in_placement {in_placement} does not match dist_tensor.dist_attr().placements {dist_tensor.dist_attr().placements}"
181+
"If reshard_inputs is wanted, set "
182+
"reshard_inputs=True to local_map."
183+
)
184+
local_tensor = dist.auto_parallel.api.dtensor_to_local(
185+
dist_tensor, process_mesh, in_placement
186+
)
187+
flat_local_args.append(local_tensor)
188+
else:
189+
flat_local_args.append(arg)
190+
191+
local_args = pack_sequence_as(args, flat_local_args)
192+
out = func(*local_args, **kwargs)
193+
original_out = out
194+
if seen_dist_tensor:
195+
flat_out = flatten(out)
196+
assert len(flat_out) == len(out_placements), (
197+
"local_map requires one PlacementType for each output value, "
198+
f"got {len(out_placements)} placements but expected "
199+
f"{len(flat_out)}!"
200+
)
201+
202+
flat_dist_and_arg_out = []
203+
for out, out_placement in zip(flat_out, out_placements):
204+
if paddle.in_dynamic_mode():
205+
if isinstance(out, paddle.Tensor):
206+
assert not dist.auto_parallel.api.is_dist_tensor(
207+
out
208+
), f"Expected dense tensor output but got {type(out)}: {out}"
209+
210+
flat_dist_and_arg_out.append(
211+
dist.auto_parallel.api.dtensor_from_local(
212+
out, process_mesh, out_placement
213+
)
214+
)
215+
else:
216+
assert out_placement is None, (
217+
f"Expected None placements for non-tensor output {out} "
218+
f"but got {out_placement}!"
219+
)
220+
flat_dist_and_arg_out.append(out)
221+
else:
222+
if isinstance(out, paddle.base.libpaddle.pir.Value):
223+
assert not dist.auto_parallel.api.is_dist_tensor(
224+
out
225+
), f"Expected dense tensor output but got {type(out)}: {out}"
226+
227+
flat_dist_and_arg_out.append(
228+
dist.auto_parallel.api.dtensor_from_local(
229+
out, process_mesh, out_placement
230+
)
231+
)
232+
else:
233+
assert out_placement is None, (
234+
f"Expected None placements for non-tensor output {out} "
235+
f"but got {out_placement}!"
236+
)
237+
flat_dist_and_arg_out.append(out)
238+
return pack_sequence_as(original_out, flat_dist_and_arg_out)
239+
else:
240+
flat_out = flatten(out)
241+
flat_dist_and_arg_out = []
242+
for out, out_placement in zip(flat_out, out_placements):
243+
if out_placement is not None:
244+
assert (
245+
process_mesh is not None
246+
), "process_mesh must be specified when out_placements is not None"
247+
flat_dist_and_arg_out.append(
248+
dist.auto_parallel.api.dtensor_from_local(
249+
out, process_mesh, out_placement
250+
)
251+
)
252+
else:
253+
flat_dist_and_arg_out.append(out)
254+
return pack_sequence_as(original_out, flat_dist_and_arg_out)
255+
256+
return functools.partial(wrapped, process_mesh)
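
To complement the docstring example above, here is a small sketch of the reshard path added in this PR (reshard_inputs=True), assuming a two-card launch; the function and variable names are illustrative:

import paddle
import paddle.distributed as dist

def local_sum(x):
    # Runs on each rank's local shard only.
    return paddle.sum(x, axis=0, keepdim=True)

dist.init_parallel_env()
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = dist.shard_tensor(paddle.rand([8, 4]), mesh, [dist.Shard(1)])

# The input arrives sharded along axis 1, but the wrapper declares Shard(0)
# inputs; with reshard_inputs=True it reshards instead of raising ValueError.
wrapped = dist.local_map(
    local_sum,
    out_placements=[[dist.Shard(0)]],
    in_placements=[[dist.Shard(0)]],
    process_mesh=mesh,
    reshard_inputs=True,
)
out = wrapped(x)  # a dist_tensor assembled from each rank's [1, 4] local result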

test/auto_parallel/hybrid_strategy/single_llama_model.py

Lines changed: 12 additions & 27 deletions
@@ -256,35 +256,20 @@ def forward(self, prediction_scores, masked_lm_labels):
             prediction_scores.astype("float32"),
             masked_lm_labels.unsqueeze(2),
         )
-        # XPU dose not support allgather mask with bool dtype, so we use LocalLayer here.
         if paddle.device.is_compiled_with_xpu():
 
-            class LocalLossLayer(paddle.distributed.LocalLayer):
-                def __init__(self, out_dist_attrs, grad_dist_attrs):
-                    super().__init__(out_dist_attrs, grad_dist_attrs)
-
-                def forward(self, x, mask):
-                    masked_lm_loss = paddle.masked_select(x, mask).astype(
-                        "float32"
-                    )
-                    loss = paddle.mean(masked_lm_loss).unsqueeze(0)
-                    return loss.unsqueeze(0)
-
-            out_dist_attrs = [
-                (
-                    masked_lm_loss.process_mesh,
-                    [dist.Shard(0), dist.Replicate()],
-                ),
-            ]
-            grad_dist_attrs = [
-                (
-                    masked_lm_loss.process_mesh,
-                    [dist.Shard(0), dist.Replicate()],
-                ),
-                None,
-            ]
-            loss_func = LocalLossLayer(out_dist_attrs, grad_dist_attrs)
-
+            def LocalLoss(x, mask):
+                masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
+                loss = paddle.mean(masked_lm_loss).unsqueeze(0)
+                return loss.unsqueeze(0)
+
+            loss_func = dist.local_map(
+                LocalLoss,
+                [[dist.Shard(0), dist.Replicate()]],
+                [[dist.Shard(0), dist.Replicate()], None],
+                masked_lm_loss.process_mesh,
+                True,
+            )
             loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
             loss = loss.mean()
             return loss
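
For readers mapping the removed LocalLayer arguments onto the new API, the positional arguments above line up with local_map's keyword parameters. An equivalent, more explicit form of the same call (a sketch, not part of the commit; LocalLoss and masked_lm_loss refer to the test code in the diff above) would be:

loss_func = dist.local_map(
    LocalLoss,
    out_placements=[[dist.Shard(0), dist.Replicate()]],
    in_placements=[[dist.Shard(0), dist.Replicate()], None],
    process_mesh=masked_lm_loss.process_mesh,
    reshard_inputs=True,
)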
