
Commit 798c1a6

Fix formatting (修改格式)
1 parent aafbc74 commit 798c1a6

File tree: 1 file changed (+59, −69 lines)


python/paddle/distributed/auto_parallel/local_map.py

Lines changed: 59 additions & 69 deletions
@@ -14,7 +14,7 @@
 from __future__ import annotations

 import functools
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Any

 import paddle
 import paddle.distributed as dist
@@ -25,10 +25,10 @@


 def local_map(
-    func: Callable,
+    func: Callable[..., Any],
     out_placements: list[list[dist.Placement]],
-    in_placements: list[list[dist.Placement]] | None,
-    process_mesh: ProcessMesh | None,
+    in_placements: list[list[dist.Placement]] | None = None,
+    process_mesh: ProcessMesh | None = None,
     reshard_inputs: bool = False,
 ):
     """
@@ -48,11 +48,12 @@ def local_map(
             When there are no dist_tensor inputs, process_mesh must be specified to use
             non-None placements.

-        in_placements (Optional[list[list[dist.Placement]]]):
+        in_placements (Optional[list[list[dist.Placement]]], optional):
             The required placements for each input tensor. If specified, must be a list
             where each element is a list of Placement objects defining the distribution
             strategy for that input tensor. The length of the outer list must match the
             number of input tensors.
+            Default: None

         process_mesh (ProcessMesh, optional):
             The process mesh that all dist_tensors are placed on. If not specified,
@@ -63,80 +64,69 @@ def local_map(
             Default: None

         reshard_inputs (bool, optional):
-            the bool value indicating whether to reshard the input :dist_tensor` s when
+            the bool value indicating whether to reshard the input :dist_tensors when
             their placements are different from the required input placements. If this
             value is ``False`` and some :dist_tensor input has a different placement,
             an exception will be raised. Default: False.

     Returns:
-        A ``Callable`` that applies ``func`` to each local shard of the input dist_tensors
-        and returns dist_tensors constructed from the return values of ``func``.
+        Callable: A function that applies func to local shards of input dist_tensors and returns dist_tensors or original values.
+
+    Example:
+        .. code-block:: python

-    Raises:
-        AssertionError: If the number of output placements does not match the number
-            of function outputs.
+            >>> from __future__ import annotations
+            >>> import paddle
+            >>> import paddle.distributed as dist
+            >>> from paddle import Tensor
+            >>> from paddle.distributed import ProcessMesh

-        AssertionError: If a non-tensor output has a non-None placement specified.
+            >>> def custom_function(x):
+            ...     mask = paddle.zeros_like(x)
+            ...     if dist.get_rank() == 0:
+            ...         mask[1:3] = 1
+            ...     else:
+            ...         mask[4:7] = 1
+            ...     x = x * mask
+            ...     mask_sum = paddle.sum(x)
+            ...     mask_sum = mask_sum / mask.sum()
+            ...     return mask_sum

-        AssertionError: If process_mesh is None and there are no dist_tensor inputs
-            but out_placements contains non-None values.
+            >>> # Initialize distributed environment
+            >>> dist.init_parallel_env()
+            >>> mesh = ProcessMesh([0, 1], dim_names=["x"])

-        ValueError: If the input dist_tensor placements don't match the required
-            in_placements.
+            >>> # Create input data
+            >>> local_input = paddle.arange(0, 10, dtype="float32")
+            >>> local_input = local_input + dist.get_rank()

-    Example:
-        >>> from __future__ import annotations
-        >>> import paddle
-        >>> import paddle.distributed as dist
-        >>> from paddle import Tensor
-        >>> from paddle.distributed import ProcessMesh
-        >>>
-        >>> def custom_function(x):
-        >>>     mask = paddle.zeros_like(x)
-        >>>     if dist.get_rank() == 0:
-        >>>         mask[1:3] = 1
-        >>>     else:
-        >>>         mask[4:7] = 1
-        >>>     x = x * mask
-        >>>     mask_sum = paddle.sum(x)
-        >>>     mask_sum = mask_sum / mask.sum()
-        >>>     return mask_sum
-        >>>
-        >>> # Initialize distributed environment
-        >>> dist.init_parallel_env()
-        >>> mesh = ProcessMesh([0, 1], dim_names=["x"])
-        >>>
-        >>> # Create input data
-        >>> local_input = paddle.arange(0, 10, dtype="float32")
-        >>> local_input = local_input + dist.get_rank()
-        >>>
-        >>> # Convert to distributed tensor
-        >>> input_dist = dist.auto_parallel.api.dtensor_from_local(
-        >>>     local_input, mesh, [dist.Shard(0)]
-        >>> )
-        >>>
-        >>> # Wrap function with local_map
-        >>> wrapped_func = dist.local_map(
-        >>>     custom_function,
-        >>>     out_placements=[dist.Partial(dist.ReduceType.kRedSum)],
-        >>>     in_placements=(dist.Shard(0),),
-        >>>     process_mesh=mesh
-        >>> )
-        >>>
-        >>> # Apply function to distributed tensor
-        >>> output_dist = wrapped_func(input_dist)
-        >>>
-        >>> # Collect and print results
-        >>> local_value = output_dist._local_value()
-        >>> gathered_values: list[Tensor] = []
-        >>> dist.all_gather(gathered_values, local_value)
-        >>>
-        >>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
-        [Rank 0] local_value=1.5
-        >>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
-        [Rank 1] local_value=6.0
-        >>> print(f"global_value (distributed)={output_dist.item()}")
-        global_value (distributed)=7.5
+            >>> # Convert to distributed tensor
+            >>> input_dist = dist.auto_parallel.api.dtensor_from_local(
+            ...     local_input, mesh, [dist.Shard(0)]
+            ... )
+
+            >>> # Wrap function with local_map
+            >>> wrapped_func = dist.local_map(
+            ...     custom_function,
+            ...     out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
+            ...     in_placements=[[dist.Shard(0)]],
+            ...     process_mesh=mesh
+            ... )
+
+            >>> # Apply function to distributed tensor
+            >>> output_dist = wrapped_func(input_dist)
+
+            >>> # Collect and print results
+            >>> local_value = output_dist._local_value()
+            >>> gathered_values: list[Tensor] = []
+            >>> dist.all_gather(gathered_values, local_value)
+
+            >>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
+            [Rank 0] local_value=1.5
+            >>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
+            [Rank 1] local_value=6.0
+            >>> print(f"global_value (distributed)={output_dist.item()}")
+            global_value (distributed)=7.5
     """

     def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs):
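
The substantive part of this change is the signature: in_placements and process_mesh now default to None, and func is typed as Callable[..., Any]. Below is a minimal sketch (not taken from the commit) of a call site that relies on those new defaults, assuming that omitting in_placements skips the input-placement check and that the mesh is then taken from the dist_tensor inputs; the helper name local_sum and the two-rank launch are illustrative only. Launch with paddle.distributed.launch on two devices, mirroring the docstring example above.

import paddle
import paddle.distributed as dist
from paddle.distributed import ProcessMesh

def local_sum(x):
    # Runs on the local shard only; the Partial(kRedSum) out_placement marks
    # the result as a partial sum to be reduced across ranks.
    return paddle.sum(x)

dist.init_parallel_env()
mesh = ProcessMesh([0, 1], dim_names=["x"])

local_input = paddle.arange(0, 10, dtype="float32") + dist.get_rank()
input_dist = dist.auto_parallel.api.dtensor_from_local(
    local_input, mesh, [dist.Shard(0)]
)

# With this commit, in_placements and process_mesh can simply be omitted
# (both default to None); only out_placements is still required.
wrapped = dist.local_map(
    local_sum,
    out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
)
output_dist = wrapped(input_dist)
print(output_dist)  # dist_tensor whose global value is the sum over both shards

This is the same call pattern as the updated docstring example, minus the two arguments that are now optional.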
