Add Chinese documentation for the local_map API #7245

Merged
merged 20 commits into from May 12, 2025
Changes from 16 commits
88 changes: 88 additions & 0 deletions docs/api/paddle/distributed/local_map_cn.rst
@@ -0,0 +1,88 @@
.. _cn_api_paddle_distributed_local_map:

local_map
-------------------------------

.. py:function:: paddle.distributed.local_map(func, out_placements, in_placements=None, process_mesh=None, reshard_inputs=False)

local_map is a function decorator that lets users pass distributed tensors (DTensors) to a function written for ordinary tensors. It works by extracting the local shard of each distributed tensor, calling the target function on it, and wrapping the outputs back into distributed tensors according to out_placements. Because the tensor conversions are handled automatically, users can write these local operations just as they would write single-device code.
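
As a minimal sketch of the wrapping flow (the helper sum_local below is hypothetical, and the snippet assumes a two-rank launch, e.g. via paddle.distributed.launch):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    def sum_local(x):
        # Written against a plain Tensor: x is the local shard on each rank.
        return paddle.sum(x)

    dist.init_parallel_env()
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    # Wrap once; the returned callable accepts distributed tensors.
    wrapped = dist.local_map(
        sum_local,
        out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
        in_placements=[[dist.Shard(0)]],
        process_mesh=mesh,
    )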


Parameters
:::::::::

- **func** (Callable) - The function to apply to the local shards of the distributed tensors.
- **out_placements** (list[list[dist.Placement]]) - The distribution strategy of the output tensors. The outer list's length must match the number of function outputs, and each inner list describes how the corresponding output tensor is distributed. Must be None for non-tensor outputs.
- **in_placements** (list[list[dist.Placement]]|None, optional) - The required distribution of the input tensors. If specified, each inner list describes the required placement of the corresponding input tensor, and the outer list's length must match the number of input tensors. Should be None for inputs without distributed attributes. Defaults to None, meaning no particular placement is required of the inputs.
- **process_mesh** (ProcessMesh|None, optional) - The process mesh on which the computation runs. All distributed tensors must live on the same process_mesh. Defaults to None, in which case the mesh is inferred from the input tensors.
- **reshard_inputs** (bool, optional) - Whether to automatically reshard the input distributed tensors when their placements do not match the required in_placements. Defaults to False, meaning a mismatch raises an error instead of triggering a reshard.
Collaborator:

Follow the established style for API parameters:

  • Optional parameters should be written in the form in_placements (list[list[dist.Placement]],可选).
  • Mind the punctuation: it must all be Chinese punctuation. "输入应设为 None,默认为 None" ("should be set to None, defaults to None") uses an English comma in the middle; fix it. Also, don't forget the trailing period on each item.
  • For every optional parameter, state both 1. the default value and 2. what that default means. process_mesh, for example, states neither.

Contributor Author:

Done

Member @SigureMo, Apr 15, 2025:

> Optional parameters should be written in the form in_placements (list[list[dist.Placement]],可选)

Don't make claims about things you don't understand: Optional in a type hint and "optional" here mean completely different things.

In a type hint, Optional[T] means the value may be of type T or None; the more common modern syntax is T | None.

An optional parameter is a parameter that has a default value.

The two are not strongly related!

def fn1(x: int): ...               # type can only be int
def fn2(x: int | None): ...        # type can be int | None; note this is NOT an optional parameter, since it has no default value
def fn3(x: int = 1): ...           # type can only be int, has a default value, i.e. an optional parameter
def fn4(x: int | None = None): ... # type can be int | None, and is also an optional parameter with default value None

Collaborator @sunzhongkai588, Apr 16, 2025:

@SigureMo The type hint in the source is in_placements: list[list[dist.Placement]] | None = None, so isn't it an optional parameter?

Member:

It wasn't before; that was changed just yesterday.

Member:

Also, list[list[dist.Placement]] | None is the type hint; the = None part is not a type hint.

Collaborator:

That was a fast reply, 一师傅. So this should be changed to in_placements (list[list[dist.Placement]] | None,可选), right?

Member:

That works.
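
The outcome of this thread, then, is that the entry should read in_placements (list[list[dist.Placement]] | None,可选) and should also spell out the default value (None) and what that default means.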


Returns
:::::::::

A callable that applies func to each local shard of the input distributed tensors and builds new distributed tensors from the return values according to out_placements.


Raises
:::::::::

- **AssertionError** - Raised when the number of output placements does not match the number of function outputs.
- **AssertionError** - Raised when a non-tensor output is given a placement other than None.
- **AssertionError** - Raised when process_mesh is None and there are no distributed tensor inputs, yet out_placements contains non-None values.
- **ValueError** - Raised when an input distributed tensor's placements do not match the required in_placements and reshard_inputs is False.
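
A hedged sketch of the ValueError case, assuming a two-rank launch: the input below is replicated while in_placements demands Shard(0), so with reshard_inputs=False the call is expected to fail, and reshard_inputs=True would reshard the input automatically instead.

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    # Replicated input: every rank holds the full tensor.
    x = dist.shard_tensor(
        paddle.arange(0, 10, dtype='float32'), mesh, [dist.Replicate()]
    )

    wrapped = dist.local_map(
        lambda t: paddle.sum(t),
        out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
        in_placements=[[dist.Shard(0)]],  # does not match the actual Replicate placement
        reshard_inputs=False,             # mismatch is expected to raise ValueError
    )
    out = wrapped(x)  # expected ValueError; set reshard_inputs=True to reshard instead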


Code Examples
:::::::::

.. code-block:: python
Collaborator:

Chinese example code should uniformly use the COPY-FROM: paddle.xxx form, where xxx is the API's call path, so that it stays consistent with the English code.

Contributor Author:

Done


import paddle
import paddle.distributed as dist
from paddle import Tensor
from paddle.distributed import ProcessMesh

def custom_function(x):
    # Each rank masks a different slice of its local shard.
    mask = paddle.zeros_like(x)
    if dist.get_rank() == 0:
        mask[1:3] = 1
    else:
        mask[4:7] = 1
    x = x * mask
    # Mean of the masked elements on this rank.
    mask_sum = paddle.sum(x)
    mask_sum = mask_sum / mask.sum()
    return mask_sum

dist.init_parallel_env()
mesh = ProcessMesh([0, 1], dim_names=["x"])

local_input = paddle.arange(0, 10, dtype='float32')
local_input = local_input + dist.get_rank()

input_dist = dist.auto_parallel.api.dtensor_from_local(
    local_input,
    mesh,
    [dist.Shard(0)]
)

# Wrap the function with local_map
wrapped_func = dist.local_map(
    custom_function,
    out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
    in_placements=[[dist.Shard(0)]],
    process_mesh=mesh
)

# Apply the wrapped function to the distributed tensor
output_dist = wrapped_func(input_dist)

# Gather and print the results
local_value = output_dist._local_value()
gathered_values: list[Tensor] = []
dist.all_gather(gathered_values, local_value)
print(f"[Rank 0] local_value={gathered_values[0].item()}")
# [Rank 0] local_value=1.5
print(f"[Rank 1] local_value={gathered_values[1].item()}")
# [Rank 1] local_value=6.0
print(f"global_value (distributed)={output_dist.item()}")
# global_value (distributed)=7.5
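
A quick check of the printed numbers: on rank 0 the local shard is [0, 1, ..., 9] and the mask keeps indices 1:3, so the masked mean is (1 + 2) / 2 = 1.5; on rank 1 the shard is [1, 2, ..., 10] and the mask keeps indices 4:7, giving (5 + 6 + 7) / 3 = 6.0. The Partial(kRedSum) placement then sums the per-rank results into the global value 1.5 + 6.0 = 7.5. Note the example needs two ranks; one typical way to run it (assuming a standard setup, with demo.py as a hypothetical file holding the snippet) is python -m paddle.distributed.launch --devices=0,1 demo.py.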