# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Callable

import paddle
import paddle.distributed as dist
from paddle.utils import flatten, pack_sequence_as

if TYPE_CHECKING:
    from paddle.distributed import ProcessMesh


def local_map(
    func: Callable[..., Any],
    out_placements: list[list[dist.Placement]],
    in_placements: list[list[dist.Placement]] | None = None,
    process_mesh: ProcessMesh | None = None,
    reshard_inputs: bool = False,
) -> Callable[..., Any]:
| 34 | + """ |
| 35 | + The `local_map` API allows users to pass dist_tensors to a function that is written |
| 36 | + to be applied on ``paddle.Tensor`` s. It works by extracting the local components |
| 37 | + of dist_tensors, calling the function, and wrapping the outputs as dist_tensors |
| 38 | + according to the ``out_placements``. |
| 39 | +
|
| 40 | + Args: |
| 41 | + func (Callable): The function to be applied on each local shard of dist_tensors. |
| 42 | +
|
| 43 | + out_placements (list[list[dist.Placement]]): |
| 44 | + The desired placements for each output tensor. Must be a list where each element |
| 45 | + is a list of Placement objects specifying the distribution strategy for that |
| 46 | + output tensor. The length of the outer list must match the number of outputs |
| 47 | + from ``func``. For non-tensor outputs, the corresponding placement must be None. |
| 48 | + When there are no dist_tensor inputs, process_mesh must be specified to use |
| 49 | + non-None placements. |
| 50 | +
|
| 51 | + in_placements (Optional[list[list[dist.Placement]]], optional): |
| 52 | + The required placements for each input tensor. If specified, must be a list |
| 53 | + where each element is a list of Placement objects defining the distribution |
| 54 | + strategy for that input tensor. The length of the outer list must match the |
| 55 | + number of input tensors. |
| 56 | + Default: None |
| 57 | +
|
| 58 | + process_mesh (ProcessMesh, optional): |
| 59 | + The process mesh that all dist_tensors are placed on. If not specified, |
| 60 | + this will be inferred from the input dist_tensors' process mesh. |
| 61 | + local_map requires all dist_tensors to be placed on the same process mesh. |
| 62 | + Must be specified when there are no dist_tensor inputs but out_placements |
| 63 | + contains non-None values. |
| 64 | + Default: None |
| 65 | +
|
        reshard_inputs (bool, optional):
            Whether to reshard the input dist_tensors when their placements are different
            from the required input placements. If this value is ``False`` and some
            dist_tensor input has a different placement, an exception will be raised.
            Default: False

    Returns:
        Callable: A function that applies ``func`` to the local shards of the input dist_tensors and returns dist_tensors or the original non-tensor values.

    Example:
        .. code-block:: python

            >>> from __future__ import annotations
            >>> import paddle
            >>> import paddle.distributed as dist
            >>> from paddle import Tensor
            >>> from paddle.distributed import ProcessMesh

            >>> def custom_function(x):
            ...     mask = paddle.zeros_like(x)
            ...     if dist.get_rank() == 0:
            ...         mask[1:3] = 1
            ...     else:
            ...         mask[4:7] = 1
            ...     x = x * mask
            ...     mask_sum = paddle.sum(x)
            ...     mask_sum = mask_sum / mask.sum()
            ...     return mask_sum

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> dist.init_parallel_env()
            >>> mesh = ProcessMesh([0, 1], dim_names=["x"])
            >>> local_input = paddle.arange(0, 10, dtype="float32")
            >>> local_input = local_input + dist.get_rank()
            >>> input_dist = dist.auto_parallel.api.dtensor_from_local(
            ...     local_input, mesh, [dist.Shard(0)]
            ... )
            >>> wrapped_func = dist.local_map(
            ...     custom_function,
            ...     out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
            ...     in_placements=[[dist.Shard(0)]],
            ...     process_mesh=mesh
            ... )
            >>> output_dist = wrapped_func(input_dist)

            >>> local_value = output_dist._local_value()
            >>> gathered_values: list[Tensor] = []
            >>> dist.all_gather(gathered_values, local_value)

            >>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
            [Rank 0] local_value=1.5
            >>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
            [Rank 1] local_value=6.0
            >>> print(f"global_value (distributed)={output_dist.item()}")
            global_value (distributed)=7.5

            >>> # This case needs to be executed in a multi-card environment
            >>> # export CUDA_VISIBLE_DEVICES=0,1
            >>> # python -m paddle.distributed.launch {test_case}.py
    """

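    # Implementation note: ``wrapped`` receives the (possibly None) process_mesh as
    # its first argument and works in three phases: (1) unwrap dist_tensor inputs
    # into their local shards, (2) call ``func`` on those plain tensors, and
    # (3) re-wrap tensor outputs as dist_tensors according to ``out_placements``.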
    def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs):
        # Process input arguments
        flat_dist_args = flatten(args)
        if in_placements is not None:
            assert len(in_placements) == len(flat_dist_args), (
                f"in_placements length {len(in_placements)} does not match "
                f"number of input args {len(flat_dist_args)}!"
            )

        flat_local_args = []
        seen_dist_tensor = False

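        # Walk the flattened inputs: dist_tensor args are converted to their local
        # shards (optionally validated or resharded against ``in_placements``),
        # while non-tensor args are passed through unchanged.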
        for idx, arg in enumerate(flat_dist_args):
            if dist.auto_parallel.api.is_dist_tensor(arg):
                dist_tensor = arg
                if process_mesh is None:
                    if paddle.in_dynamic_mode():
                        process_mesh = dist_tensor.process_mesh
                    else:
                        process_mesh = dist_tensor.dist_attr().process_mesh

                seen_dist_tensor = True

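                # Determine the required input placement: use the entry from
                # ``in_placements`` when given (resharding or erroring on mismatch),
                # otherwise fall back to the tensor's current placements.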
                if in_placements is not None:
                    in_placement = in_placements[idx]
                    if in_placement is None:
                        if paddle.in_dynamic_mode():
                            in_placement = dist_tensor.placements
                        else:
                            in_placement = dist_tensor.dist_attr().placements
                    else:
                        if paddle.in_dynamic_mode():
                            if in_placement != dist_tensor.placements:
                                if reshard_inputs:
                                    dist_tensor = dist.reshard(
                                        dist_tensor, process_mesh, in_placement
                                    )
                                else:
                                    raise ValueError(
                                        f"in_placement {in_placement} does not match "
                                        f"dist_tensor.placements {dist_tensor.placements}. "
                                        "Set reshard_inputs=True in local_map if "
                                        "resharding is desired."
                                    )
                        else:
                            if (
                                in_placement
                                != dist_tensor.dist_attr().placements
                            ):
                                if reshard_inputs:
                                    dist_tensor = dist.reshard(
                                        dist_tensor, process_mesh, in_placement
                                    )
                                else:
                                    raise ValueError(
                                        f"in_placement {in_placement} does not match "
                                        "dist_tensor.dist_attr().placements "
                                        f"{dist_tensor.dist_attr().placements}. "
                                        "Set reshard_inputs=True in local_map if "
                                        "resharding is desired."
                                    )
                else:
                    # No required placements were given; keep the tensor's current
                    # placements for the local conversion below.
                    if paddle.in_dynamic_mode():
                        in_placement = dist_tensor.placements
                    else:
                        in_placement = dist_tensor.dist_attr().placements

                local_tensor = dist.auto_parallel.api.dtensor_to_local(
                    dist_tensor, process_mesh, in_placement
                )
                flat_local_args.append(local_tensor)
            else:
                flat_local_args.append(arg)

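        # Restore the original nested structure and run ``func`` on plain
        # ``paddle.Tensor`` inputs; keyword arguments are forwarded untouched.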
        local_args = pack_sequence_as(args, flat_local_args)
        out = func(*local_args, **kwargs)
        original_out = out
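        # If any input was a dist_tensor, every tensor output must be wrapped back
        # into a dist_tensor on the same process mesh using ``out_placements``;
        # non-tensor outputs must have a None placement.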
        if seen_dist_tensor:
            flat_out = flatten(out)
            assert len(flat_out) == len(out_placements), (
                "local_map requires one output placement for each output value, "
                f"got {len(out_placements)} placements but expected "
                f"{len(flat_out)}!"
            )

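            # In dynamic (dygraph) mode tensor outputs are ``paddle.Tensor``; in
            # static-graph (PIR) mode they are ``pir.Value``. Both branches wrap
            # dense tensor outputs with ``dtensor_from_local`` and pass non-tensor
            # outputs through unchanged.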
            flat_dist_and_arg_out = []
            for out, out_placement in zip(flat_out, out_placements):
                if paddle.in_dynamic_mode():
                    if isinstance(out, paddle.Tensor):
                        assert not dist.auto_parallel.api.is_dist_tensor(
                            out
                        ), f"Expected dense tensor output but got {type(out)}: {out}"

                        flat_dist_and_arg_out.append(
                            dist.auto_parallel.api.dtensor_from_local(
                                out, process_mesh, out_placement
                            )
                        )
                    else:
                        assert out_placement is None, (
                            f"Expected None placements for non-tensor output {out} "
                            f"but got {out_placement}!"
                        )
                        flat_dist_and_arg_out.append(out)
                else:
                    if isinstance(out, paddle.base.libpaddle.pir.Value):
                        assert not dist.auto_parallel.api.is_dist_tensor(
                            out
                        ), f"Expected dense tensor output but got {type(out)}: {out}"

                        flat_dist_and_arg_out.append(
                            dist.auto_parallel.api.dtensor_from_local(
                                out, process_mesh, out_placement
                            )
                        )
                    else:
                        assert out_placement is None, (
                            f"Expected None placements for non-tensor output {out} "
                            f"but got {out_placement}!"
                        )
                        flat_dist_and_arg_out.append(out)
            return pack_sequence_as(original_out, flat_dist_and_arg_out)
        else:
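            # No dist_tensor inputs were seen, so no process mesh could be inferred.
            # Outputs may still be wrapped as dist_tensors, but only if an explicit
            # ``process_mesh`` was passed to local_map.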
            flat_out = flatten(out)
            flat_dist_and_arg_out = []
            for out, out_placement in zip(flat_out, out_placements):
                if out_placement is not None:
                    assert (
                        process_mesh is not None
                    ), "process_mesh must be specified when out_placements contains non-None values"
                    flat_dist_and_arg_out.append(
                        dist.auto_parallel.api.dtensor_from_local(
                            out, process_mesh, out_placement
                        )
                    )
                else:
                    flat_dist_and_arg_out.append(out)
            return pack_sequence_as(original_out, flat_dist_and_arg_out)

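    # Bind the user-supplied process_mesh (possibly None) as the first argument so
    # that ``wrapped`` can fall back to inferring it from the dist_tensor inputs.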
    return functools.partial(wrapped, process_mesh)