 from __future__ import annotations
 
 import functools
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Any
 
 import paddle
 import paddle.distributed as dist
@@ ... @@
 
 
 def local_map(
-    func: Callable,
+    func: Callable[..., Any],
     out_placements: list[list[dist.Placement]],
-    in_placements: list[list[dist.Placement]] | None,
-    process_mesh: ProcessMesh | None,
+    in_placements: list[list[dist.Placement]] | None = None,
+    process_mesh: ProcessMesh | None = None,
     reshard_inputs: bool = False,
 ):
     """
@@ -48,11 +48,12 @@ def local_map(
             When there are no dist_tensor inputs, process_mesh must be specified to use
             non-None placements.
 
-        in_placements (Optional[list[list[dist.Placement]]]):
+        in_placements (Optional[list[list[dist.Placement]]], optional):
             The required placements for each input tensor. If specified, must be a list
             where each element is a list of Placement objects defining the distribution
             strategy for that input tensor. The length of the outer list must match the
             number of input tensors.
+            Default: None
 
         process_mesh (ProcessMesh, optional):
             The process mesh that all dist_tensors are placed on. If not specified,
@@ -63,80 +64,69 @@ def local_map(
             Default: None
 
         reshard_inputs (bool, optional):
-            the bool value indicating whether to reshard the input :dist_tensor` s when
+            the bool value indicating whether to reshard the input :dist_tensors when
             their placements are different from the required input placements. If this
             value is ``False`` and some :dist_tensor input has a different placement,
             an exception will be raised. Default: False.
 
     Returns:
-        A ``Callable`` that applies ``func`` to each local shard of the input dist_tensors
-        and returns dist_tensors constructed from the return values of ``func``.
+        Callable: A function that applies func to local shards of input dist_tensors and returns dist_tensors or original values.
+
+    Example:
+        .. code-block:: python
 
-    Raises:
-        AssertionError: If the number of output placements does not match the number
-            of function outputs.
+            >>> from __future__ import annotations
+            >>> import paddle
+            >>> import paddle.distributed as dist
+            >>> from paddle import Tensor
+            >>> from paddle.distributed import ProcessMesh
 
-        AssertionError: If a non-tensor output has a non-None placement specified.
+            >>> def custom_function(x):
+            ...     mask = paddle.zeros_like(x)
+            ...     if dist.get_rank() == 0:
+            ...         mask[1:3] = 1
+            ...     else:
+            ...         mask[4:7] = 1
+            ...     x = x * mask
+            ...     mask_sum = paddle.sum(x)
+            ...     mask_sum = mask_sum / mask.sum()
+            ...     return mask_sum
 
-        AssertionError: If process_mesh is None and there are no dist_tensor inputs
-            but out_placements contains non-None values.
+            >>> # Initialize distributed environment
+            >>> dist.init_parallel_env()
+            >>> mesh = ProcessMesh([0, 1], dim_names=["x"])
 
-        ValueError: If the input dist_tensor placements don't match the required
-            in_placements.
+            >>> # Create input data
+            >>> local_input = paddle.arange(0, 10, dtype="float32")
+            >>> local_input = local_input + dist.get_rank()
 
-    Example:
-        >>> from __future__ import annotations
-        >>> import paddle
-        >>> import paddle.distributed as dist
-        >>> from paddle import Tensor
-        >>> from paddle.distributed import ProcessMesh
-        >>>
-        >>> def custom_function(x):
-        >>>     mask = paddle.zeros_like(x)
-        >>>     if dist.get_rank() == 0:
-        >>>         mask[1:3] = 1
-        >>>     else:
-        >>>         mask[4:7] = 1
-        >>>     x = x * mask
-        >>>     mask_sum = paddle.sum(x)
-        >>>     mask_sum = mask_sum / mask.sum()
-        >>>     return mask_sum
-        >>>
-        >>> # Initialize distributed environment
-        >>> dist.init_parallel_env()
-        >>> mesh = ProcessMesh([0, 1], dim_names=["x"])
-        >>>
-        >>> # Create input data
-        >>> local_input = paddle.arange(0, 10, dtype="float32")
-        >>> local_input = local_input + dist.get_rank()
-        >>>
-        >>> # Convert to distributed tensor
-        >>> input_dist = dist.auto_parallel.api.dtensor_from_local(
-        >>>     local_input, mesh, [dist.Shard(0)]
-        >>> )
-        >>>
-        >>> # Wrap function with local_map
-        >>> wrapped_func = dist.local_map(
-        >>>     custom_function,
-        >>>     out_placements=[dist.Partial(dist.ReduceType.kRedSum)],
-        >>>     in_placements=(dist.Shard(0),),
-        >>>     process_mesh=mesh
-        >>> )
-        >>>
-        >>> # Apply function to distributed tensor
-        >>> output_dist = wrapped_func(input_dist)
-        >>>
-        >>> # Collect and print results
-        >>> local_value = output_dist._local_value()
-        >>> gathered_values: list[Tensor] = []
-        >>> dist.all_gather(gathered_values, local_value)
-        >>>
-        >>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
-        [Rank 0] local_value=1.5
-        >>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
-        [Rank 1] local_value=6.0
-        >>> print(f"global_value (distributed)={output_dist.item()}")
-        global_value (distributed)=7.5
+            >>> # Convert to distributed tensor
+            >>> input_dist = dist.auto_parallel.api.dtensor_from_local(
+            ...     local_input, mesh, [dist.Shard(0)]
+            ... )
+
+            >>> # Wrap function with local_map
+            >>> wrapped_func = dist.local_map(
+            ...     custom_function,
+            ...     out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
+            ...     in_placements=[[dist.Shard(0)]],
+            ...     process_mesh=mesh
+            ... )
+
+            >>> # Apply function to distributed tensor
+            >>> output_dist = wrapped_func(input_dist)
+
+            >>> # Collect and print results
+            >>> local_value = output_dist._local_value()
+            >>> gathered_values: list[Tensor] = []
+            >>> dist.all_gather(gathered_values, local_value)
+
+            >>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
+            [Rank 0] local_value=1.5
+            >>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
+            [Rank 1] local_value=6.0
+            >>> print(f"global_value (distributed)={output_dist.item()}")
+            global_value (distributed)=7.5
     """
 
     def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs):
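
The signature change above gives `in_placements` and `process_mesh` defaults of `None`. The following sketch (not part of the diff) illustrates what that allows: calling `local_map` with only `func` and `out_placements` when the input is already a dist_tensor. It reuses only calls that appear in the docstring example; the helper name `sum_local` is hypothetical, and the assumption that the mesh is picked up from the dist_tensor input follows from the docstring's note that `process_mesh` only needs to be given when there are no dist_tensor inputs.

# Hypothetical usage sketch for the defaulted signature (not part of this PR's diff).
from __future__ import annotations

import paddle
import paddle.distributed as dist
from paddle.distributed import ProcessMesh


def sum_local(x):
    # Runs on each rank's local shard of x.
    return paddle.sum(x)


dist.init_parallel_env()
mesh = ProcessMesh([0, 1], dim_names=["x"])

local_input = paddle.arange(0, 10, dtype="float32") + dist.get_rank()
input_dist = dist.auto_parallel.api.dtensor_from_local(
    local_input, mesh, [dist.Shard(0)]
)

# in_placements and process_mesh are omitted: no input-placement check is
# requested, and the mesh is presumably inferred from the dist_tensor input.
wrapped = dist.local_map(
    sum_local,
    out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
)
output_dist = wrapped(input_dist)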