
Add new API local_map #71804


Merged · 18 commits · Apr 24, 2025
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -60,6 +60,7 @@
    SequenceParallelEnd,
)
from .auto_parallel.local_layer import LocalLayer
from .auto_parallel.local_map import local_map
from .auto_parallel.placement_type import (
    Partial,
    Replicate,
@@ -192,6 +193,7 @@
"Strategy",
"DistModel",
"LocalLayer",
"local_map",
"unshard_dtensor",
"parallelize",
"SequenceParallelEnd",
23 changes: 23 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -1061,6 +1061,29 @@ def replicate_layer_params_and_buffers(
)


def is_dist_tensor(tensor) -> bool:
    """
    Check if an input is a dist_tensor in both dynamic and static modes.

    Args:
        tensor: The input to check.

    Returns:
        bool: True if the input is a dist_tensor, False otherwise.
    """
    if paddle.in_dynamic_mode():
        return (
            isinstance(tensor, paddle.Tensor)
            and hasattr(tensor, 'is_dist')
            and tensor.is_dist()
        )
    else:
        return (
            isinstance(tensor, paddle.base.libpaddle.pir.Value)
            and tensor.dist_attr() is not None
        )


class _ShardOptimizer(Optimizer):
    def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
        assert (
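The `is_dist_tensor` helper added above is used by `local_map` below to detect dist_tensors uniformly across dynamic and static (PIR) modes. A minimal dynamic-mode sketch of its behavior, assuming a two-card launch; the mesh and tensors here are illustrative, not from this PR:

# Dynamic-mode sketch: is_dist_tensor distinguishes plain dense tensors
# from tensors that carry distributed attributes.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

x = paddle.ones([4, 8])
print(dist.auto_parallel.api.is_dist_tensor(x))  # False: plain dense tensor

x_dist = dist.shard_tensor(x, mesh, [dist.Shard(0)])
print(dist.auto_parallel.api.is_dist_tensor(x_dist))  # True: dist_tensor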
256 changes: 256 additions & 0 deletions python/paddle/distributed/auto_parallel/local_map.py
@@ -0,0 +1,256 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Callable

import paddle
import paddle.distributed as dist
from paddle.utils import flatten, pack_sequence_as

if TYPE_CHECKING:
    from paddle.distributed import ProcessMesh


def local_map(
    func: Callable[..., Any],
    out_placements: list[list[dist.Placement]],
    in_placements: list[list[dist.Placement]] | None = None,
    process_mesh: ProcessMesh | None = None,
    reshard_inputs: bool = False,
) -> Callable[..., Any]:
"""
The `local_map` API allows users to pass dist_tensors to a function that is written
to be applied on ``paddle.Tensor`` s. It works by extracting the local components
of dist_tensors, calling the function, and wrapping the outputs as dist_tensors
according to the ``out_placements``.

Args:
func (Callable): The function to be applied on each local shard of dist_tensors.

out_placements (list[list[dist.Placement]]):
The desired placements for each output tensor. Must be a list where each element
is a list of Placement objects specifying the distribution strategy for that
output tensor. The length of the outer list must match the number of outputs
from ``func``. For non-tensor outputs, the corresponding placement must be None.
When there are no dist_tensor inputs, process_mesh must be specified to use
non-None placements.

in_placements (Optional[list[list[dist.Placement]]], optional):
The required placements for each input tensor. If specified, must be a list
where each element is a list of Placement objects defining the distribution
strategy for that input tensor. The length of the outer list must match the
number of input tensors.
Default: None

process_mesh (ProcessMesh, optional):
The process mesh that all dist_tensors are placed on. If not specified,
this will be inferred from the input dist_tensors' process mesh.
local_map requires all dist_tensors to be placed on the same process mesh.
Must be specified when there are no dist_tensor inputs but out_placements
contains non-None values.
Default: None

reshard_inputs (bool, optional):
the bool value indicating whether to reshard the input :dist_tensors when
their placements are different from the required input placements. If this
value is ``False`` and some :dist_tensor input has a different placement,
an exception will be raised. Default: False.

Returns:
Callable: A function that applies func to local shards of input dist_tensors and returns dist_tensors or original values.

    Example:
        .. code-block:: python

            >>> from __future__ import annotations
            >>> import paddle
            >>> import paddle.distributed as dist
            >>> from paddle import Tensor
            >>> from paddle.distributed import ProcessMesh

            >>> def custom_function(x):
            ...     mask = paddle.zeros_like(x)
            ...     if dist.get_rank() == 0:
            ...         mask[1:3] = 1
            ...     else:
            ...         mask[4:7] = 1
            ...     x = x * mask
            ...     mask_sum = paddle.sum(x)
            ...     mask_sum = mask_sum / mask.sum()
            ...     return mask_sum

            >>> # This example must run in a multi-card environment, e.g.
            >>> # export CUDA_VISIBLE_DEVICES=0,1
            >>> # python -m paddle.distributed.launch {test_case}.py
            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> dist.init_parallel_env()
            >>> mesh = ProcessMesh([0, 1], dim_names=["x"])
            >>> local_input = paddle.arange(0, 10, dtype="float32")
            >>> local_input = local_input + dist.get_rank()
            >>> input_dist = dist.auto_parallel.api.dtensor_from_local(
            ...     local_input, mesh, [dist.Shard(0)]
            ... )
            >>> wrapped_func = dist.local_map(
            ...     custom_function,
            ...     out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
            ...     in_placements=[[dist.Shard(0)]],
            ...     process_mesh=mesh
            ... )
            >>> output_dist = wrapped_func(input_dist)

            >>> local_value = output_dist._local_value()
            >>> gathered_values: list[Tensor] = []
            >>> dist.all_gather(gathered_values, local_value)

            >>> print(f"[Rank 0] local_value={gathered_values[0].item()}")
            [Rank 0] local_value=1.5
            >>> print(f"[Rank 1] local_value={gathered_values[1].item()}")
            [Rank 1] local_value=6.0
            >>> print(f"global_value (distributed)={output_dist.item()}")
            global_value (distributed)=7.5
    """

    def wrapped(process_mesh: ProcessMesh | None, *args, **kwargs):
        # Process input arguments: flatten nested containers so each leaf
        # can be checked and converted individually.
        flat_dist_args = flatten(args)
        if in_placements is not None:
            assert len(in_placements) == len(flat_dist_args), (
                f"in_placements length {len(in_placements)} does not match "
                f"number of input args {len(flat_dist_args)}!"
            )

        flat_local_args = []
        seen_dist_tensor = False

        for idx, arg in enumerate(flat_dist_args):
            if dist.auto_parallel.api.is_dist_tensor(arg):
                dist_tensor = arg
                if process_mesh is None:
                    if paddle.in_dynamic_mode():
                        process_mesh = dist_tensor.process_mesh
                    else:
                        process_mesh = dist_tensor.dist_attr().process_mesh

                seen_dist_tensor = True

                if in_placements is not None:
                    in_placement = in_placements[idx]
                    if in_placement is None:
                        if paddle.in_dynamic_mode():
                            in_placement = dist_tensor.placements
                        else:
                            in_placement = dist_tensor.dist_attr().placements
                    else:
                        if paddle.in_dynamic_mode():
                            if in_placement != dist_tensor.placements:
                                if reshard_inputs:
                                    dist_tensor = dist.reshard(
                                        dist_tensor, process_mesh, in_placement
                                    )
                                else:
                                    raise ValueError(
                                        f"in_placement {in_placement} does not match "
                                        f"dist_tensor.placements {dist_tensor.placements}. "
                                        "If resharding is wanted, pass "
                                        "reshard_inputs=True to local_map."
                                    )
                        else:
                            if (
                                in_placement
                                != dist_tensor.dist_attr().placements
                            ):
                                if reshard_inputs:
                                    dist_tensor = dist.reshard(
                                        dist_tensor, process_mesh, in_placement
                                    )
                                else:
                                    raise ValueError(
                                        f"in_placement {in_placement} does not match "
                                        "dist_tensor.dist_attr().placements "
                                        f"{dist_tensor.dist_attr().placements}. "
                                        "If resharding is wanted, pass "
                                        "reshard_inputs=True to local_map."
                                    )
                else:
                    # No required placements were given; fall back to the
                    # tensor's current placements so the local conversion
                    # below always has a valid placement to work with.
                    if paddle.in_dynamic_mode():
                        in_placement = dist_tensor.placements
                    else:
                        in_placement = dist_tensor.dist_attr().placements
                local_tensor = dist.auto_parallel.api.dtensor_to_local(
                    dist_tensor, process_mesh, in_placement
                )
                flat_local_args.append(local_tensor)
            else:
                flat_local_args.append(arg)

        local_args = pack_sequence_as(args, flat_local_args)
        out = func(*local_args, **kwargs)
        original_out = out
        if seen_dist_tensor:
            flat_out = flatten(out)
            assert len(flat_out) == len(out_placements), (
                "local_map requires one PlacementType for each output value, "
                f"got {len(out_placements)} placements but expected "
                f"{len(flat_out)}!"
            )

            flat_dist_and_arg_out = []
            for out, out_placement in zip(flat_out, out_placements):
                if paddle.in_dynamic_mode():
                    if isinstance(out, paddle.Tensor):
                        assert not dist.auto_parallel.api.is_dist_tensor(
                            out
                        ), f"Expected dense tensor output but got {type(out)}: {out}"

                        flat_dist_and_arg_out.append(
                            dist.auto_parallel.api.dtensor_from_local(
                                out, process_mesh, out_placement
                            )
                        )
                    else:
                        assert out_placement is None, (
                            f"Expected None placements for non-tensor output {out} "
                            f"but got {out_placement}!"
                        )
                        flat_dist_and_arg_out.append(out)
                else:
                    if isinstance(out, paddle.base.libpaddle.pir.Value):
                        assert not dist.auto_parallel.api.is_dist_tensor(
                            out
                        ), f"Expected dense tensor output but got {type(out)}: {out}"

                        flat_dist_and_arg_out.append(
                            dist.auto_parallel.api.dtensor_from_local(
                                out, process_mesh, out_placement
                            )
                        )
                    else:
                        assert out_placement is None, (
                            f"Expected None placements for non-tensor output {out} "
                            f"but got {out_placement}!"
                        )
                        flat_dist_and_arg_out.append(out)
            return pack_sequence_as(original_out, flat_dist_and_arg_out)
        else:
            flat_out = flatten(out)
            flat_dist_and_arg_out = []
            for out, out_placement in zip(flat_out, out_placements):
                if out_placement is not None:
                    assert (
                        process_mesh is not None
                    ), "process_mesh must be specified when out_placements is not None"
                    flat_dist_and_arg_out.append(
                        dist.auto_parallel.api.dtensor_from_local(
                            out, process_mesh, out_placement
                        )
                    )
                else:
                    flat_dist_and_arg_out.append(out)
            return pack_sequence_as(original_out, flat_dist_and_arg_out)

    return functools.partial(wrapped, process_mesh)
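The docstring notes that when there are no dist_tensor inputs, `process_mesh` must be supplied for any non-None `out_placements`; the `seen_dist_tensor` branch above enforces this. A minimal sketch of that path, under a two-card launch; the function and values here are illustrative, not from this PR:

# Illustrative sketch of the "no dist_tensor inputs" path: with only dense
# inputs there is no mesh to infer, so process_mesh must be passed explicitly.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def local_sum(x):
    # Runs on the plain dense tensor each rank holds.
    return x.sum()

wrapped = dist.local_map(
    local_sum,
    out_placements=[[dist.Replicate()]],
    in_placements=None,
    process_mesh=mesh,  # required: no input carries a mesh to infer from
)
out = wrapped(paddle.ones([4]))  # dense input in, dist_tensor out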
39 changes: 12 additions & 27 deletions test/auto_parallel/hybrid_strategy/single_llama_model.py
@@ -256,35 +256,20 @@ def forward(self, prediction_scores, masked_lm_labels):
            prediction_scores.astype("float32"),
            masked_lm_labels.unsqueeze(2),
        )
        # XPU does not support allgather mask with bool dtype, so we use LocalLayer here.
        if paddle.device.is_compiled_with_xpu():

            class LocalLossLayer(paddle.distributed.LocalLayer):
                def __init__(self, out_dist_attrs, grad_dist_attrs):
                    super().__init__(out_dist_attrs, grad_dist_attrs)

                def forward(self, x, mask):
                    masked_lm_loss = paddle.masked_select(x, mask).astype(
                        "float32"
                    )
                    loss = paddle.mean(masked_lm_loss).unsqueeze(0)
                    return loss.unsqueeze(0)

            out_dist_attrs = [
                (
                    masked_lm_loss.process_mesh,
                    [dist.Shard(0), dist.Replicate()],
                ),
            ]
            grad_dist_attrs = [
                (
                    masked_lm_loss.process_mesh,
                    [dist.Shard(0), dist.Replicate()],
                ),
                None,
            ]
            loss_func = LocalLossLayer(out_dist_attrs, grad_dist_attrs)
        def LocalLoss(x, mask):
            masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
            loss = paddle.mean(masked_lm_loss).unsqueeze(0)
            return loss.unsqueeze(0)

        loss_func = dist.local_map(
            LocalLoss,
            [[dist.Shard(0), dist.Replicate()]],
            [[dist.Shard(0), dist.Replicate()], None],
            masked_lm_loss.process_mesh,
            True,
        )
        loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
        loss = loss.mean()
        return loss
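For readability, the positional arguments passed to `dist.local_map` above correspond to `(func, out_placements, in_placements, process_mesh, reshard_inputs)`; an equivalent keyword-argument form (same values, shown only for clarity):

loss_func = dist.local_map(
    LocalLoss,
    out_placements=[[dist.Shard(0), dist.Replicate()]],
    in_placements=[[dist.Shard(0), dist.Replicate()], None],
    process_mesh=masked_lm_loss.process_mesh,
    reshard_inputs=True,
)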