-
Notifications
You must be signed in to change notification settings - Fork 825
新增local_map API的中文文档 #7245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
新增local_map API的中文文档 #7245
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
07939de
【Hackathon 7th No.19】对应开发的API中文文档
zty-king e231444
【Hackathon 7th No.19】对应开发的API中文文档
zty-king ace6dcf
【Hackathon 7th No.19】对应开发的API中文文档
zty-king 2a91133
【Hackathon 7th No.19】对应开发的API中文文档
zty-king 53e3952
Merge branch 'develop' of https://github.com/PaddlePaddle/docs into my
zty-king ce022b8
【Hackathon 7th No.19】对应开发的API中文文档
zty-king 8e4ef95
Merge branch 'develop' of https://github.com/PaddlePaddle/docs into my
zty-king fbbb4f7
提交local_map的中文文档
zty-king 5079591
提交local_map的中文文档
zty-king 57cc80f
提交local_map的中文文档
zty-king ffcff3d
提交local_map的中文文档
zty-king e4aecb3
修改文档格式
zty-king 8b138fb
修改文档格式
zty-king 1c5ab83
修改文档格式
zty-king 2f20eed
修改文档格式
zty-king f3811b7
修改文档格式
zty-king cfc270a
修改格式
zty-king 2351345
添加总览
zty-king 1b9d318
修改文档格式
zty-king dd77184
修改文档格式
zty-king File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
.. _cn_api_paddle_distributed_local_map: | ||
|
||
local_map | ||
------------------------------- | ||
|
||
.. py:function:: paddle.distributed.local_map(func, out_placements, in_placements=None, process_mesh=None, reshard_inputs=False) | ||
|
||
local_map 是一个函数装饰器,允许用户将分布式张量(DTensor)传递给为普通张量(Tensor)编写的函数。它通过提取分布式张量的本地分量,调用目标函数,并根据 out_placements 将输出包装为分布式张量来实现这一功能,通过自动处理张量转换,使得用户可以像编写单卡代码一样实现这些局部操作。 | ||
|
||
|
||
参数 | ||
::::::::: | ||
|
||
- **func** (Callable) - 要应用于分布式张量本地分片的函数 | ||
- **out_placements** (list[list[dist.Placement]]) - 指定输出张量的分布策略。外层列表长度必须与函数输出数量匹配,每个内层列表描述对应输出张量的分布方式。对于非张量输出必须设为 None | ||
- **in_placements** (list[list[dist.Placement]] | None) - 指定输入张量的要求分布。如果指定,每个内层列表描述对应输入张量的分布要求。外层列表长度必须与输入张量数量匹配。对于不具有分布式属性的输入应设为 None,默认为 None | ||
- **process_mesh** (Optional[ProcessMesh]) - 计算设备网格。所有分布式张量必须位于同一个 process_mesh 上。如未指定则从输入张量推断 | ||
- **reshard_inputs** (bool) - 当输入分布式张量的分布方式与要求的 in_placements 不匹配时,是否自动重分布。默认 False | ||
|
||
返回 | ||
::::::::: | ||
|
||
返回一个可调用对象(Callable),该对象将 func 应用于输入分布式张量的每个本地分片,并根据返回值构造新的分布式张量。 | ||
|
||
|
||
异常抛出情况 | ||
::::::::: | ||
|
||
- **AssertionError** - 当输出分布策略的数量与函数输出的数量不匹配时抛出。 | ||
- **AssertionError** - 当非张量输出指定了非 None 的分布策略时抛出。 | ||
- **AssertionError** - 当 process_mesh 为 None 且没有分布式张量输入,但 out_placements 包含非 None 值时抛出。 | ||
- **ValueError** - 当输入分布式张量的分布方式与要求的 in_placements 不匹配,且 reshard_inputs 为 False 时抛出。 | ||
|
||
|
||
代码示例 | ||
::::::::: | ||
|
||
.. code-block:: python | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 中文示例代码统一用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
import paddle | ||
import paddle.distributed as dist | ||
from paddle import Tensor | ||
from paddle.distributed import ProcessMesh | ||
|
||
def custom_function(x): | ||
mask = paddle.zeros_like(x) | ||
if dist.get_rank() == 0: | ||
mask[1:3] = 1 | ||
else: | ||
mask[4:7] = 1 | ||
x = x * mask | ||
mask_sum = paddle.sum(x) | ||
mask_sum = mask_sum / mask.sum() | ||
return mask_sum | ||
|
||
dist.init_parallel_env() | ||
mesh = ProcessMesh([0, 1], dim_names=["x"]) | ||
|
||
local_input = paddle.arange(0, 10, dtype='float32') | ||
local_input = local_input + dist.get_rank() | ||
|
||
input_dist = dist.auto_parallel.api.dtensor_from_local( | ||
local_input, | ||
mesh, | ||
[dist.Shard(0)] | ||
) | ||
|
||
# 使用 local_map 包装函数 | ||
wrapped_func = dist.local_map( | ||
custom_function, | ||
out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]], | ||
in_placements=[[dist.Shard(0)]], | ||
process_mesh=mesh | ||
) | ||
|
||
# 应用函数到分布式张量 | ||
output_dist = wrapped_func(input_dist) | ||
|
||
# 收集并打印结果 | ||
local_value = output_dist._local_value() | ||
gathered_values: list[Tensor] = [] | ||
dist.all_gather(gathered_values, local_value) | ||
print(f"[Rank 0] local_value={gathered_values[0].item()}") | ||
# [Rank 0] local_value=1.5 | ||
print(f"[Rank 1] local_value={gathered_values[1].item()}") | ||
# [Rank 1] local_value=6.0 | ||
print(f"global_value (distributed)={output_dist.item()}") | ||
# global_value (distributed)=7.5 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考 API 参数写法
in_placements (list[list[dist.Placement]],可选)
的形式输入应设为 None,默认为 None
中间用的是英文,修改一下。另外结尾的句号都不要忘记There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不懂不要乱说啊,类型提示的 Optional 和这里的可选语义完全不一样!
类型提示里的
Optional[T]
代表输入类型可以是 T 或者 None,现代语法更常用的是T | None
可选参数指参数有默认值
两者没有强关联性!
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SigureMo 我看源码对应的类型提示
in_placements: list[list[dist.Placement]] | None = None
不是可选参数么There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前并不是,这是昨天刚改的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
而且
list[list[dist.Placement]] | None
是类型提示,= None
不是类型提示There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一师傅回的好快。所以这里应该改为
in_placements (list[list[dist.Placement]] | None,可选)
是么There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以这样