Add API local_map #71804


Merged
merged 18 commits into PaddlePaddle:develop
Apr 24, 2025

Conversation

zty-king
Contributor

@zty-king zty-king commented Mar 20, 2025

PR Category

Auto Parallel

PR Types

Others

Description

Add API local_map

1. Background

In distributed training, it is common to need to pass a distributed tensor (dist_tensor) to a function that can only handle ordinary tensors (dense_tensor), or that must operate on its local shard from a per-device perspective. To simplify this, a utility is needed that converts dist_tensors to ordinary tensors before the call and re-attaches the distributed attributes to the function's results afterwards. The local_map API is designed to solve exactly this problem.

2. Functional goals

The main purpose of local_map is to let users pass dist_tensors into functions written for ordinary tensors. It achieves the following:

  1. Local shard extraction: extract the local shard data from each dist_tensor
  2. Function application: apply the user function to the extracted local shards
  3. Result wrapping: re-wrap the results as dist_tensors according to the specified placements
  4. Placement validation: verify that inputs and outputs meet the declared requirements
  5. Automatic reshard: optionally reshard input tensors automatically when needed
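The first three steps above can be sketched in plain Python. This is a minimal illustration of the control flow only; all names here (FakeDistTensor, local_map_sketch) are illustrative stand-ins, not Paddle's actual implementation, and steps 4 and 5 (validation and automatic reshard) are omitted for brevity.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class FakeDistTensor:
    """Stand-in for a dist_tensor: local shard data plus a placement tag."""
    local_value: list
    placement: str  # e.g. "Shard(0)" or "Replicate"

def local_map_sketch(func: Callable[..., Any], out_placement: str):
    def wrapper(*args):
        # 1. Local shard extraction: unwrap dist tensors to their local data
        local_args = [a.local_value if isinstance(a, FakeDistTensor) else a
                      for a in args]
        # 2. Function application: run the plain-tensor user function
        out = func(*local_args)
        # 3. Result wrapping: re-attach the requested placement
        return FakeDistTensor(out, out_placement)
    return wrapper

# A plain function that knows nothing about distribution:
def double(xs):
    return [2 * x for x in xs]

wrapped = local_map_sketch(double, out_placement="Shard(0)")
result = wrapped(FakeDistTensor([1, 2, 3], "Shard(0)"))  # local value doubled
```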

3. Significance

This gives Paddle distributed training a more convenient tensor-handling path, so users can easily reuse functions written for ordinary tensors in a distributed environment.

4. Common usage scenarios

  1. Masked loss computation: computing the loss over masked tokens independently on each card

  2. MoE (Mixture of Experts) computations:

  • aux_loss: computed from the local token counts assigned to the experts on each card

  • z_loss: computed independently on each card's logits

  • tensor reshape: shape transformations over local dimensions

  3. Other scenarios that must preserve local computation semantics
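The masked-loss scenario can be illustrated without any framework: each card averages only over the unmasked tokens of its own local shard, which is why the computation must stay local. The function below is a hypothetical sketch, not Paddle code.

```python
def masked_loss(losses, mask):
    """Average loss over the unmasked tokens of the *local* shard only."""
    kept = [loss for loss, m in zip(losses, mask) if m]
    return sum(kept) / len(kept) if kept else 0.0

# Card 0 and card 1 each hold a different shard of the batch; under
# local_map semantics, each computes its masked loss independently.
card0 = masked_loss([1.0, 2.0, 3.0], [1, 0, 1])  # averages 1.0 and 3.0 -> 2.0
card1 = masked_loss([4.0, 6.0], [1, 1])          # averages 4.0 and 6.0 -> 5.0
```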

5. Advantages of local_map over LocalLayer

  1. Structurally, local_map is a function wrapper that decorates ordinary functions directly. There is no need to subclass Layer as with LocalLayer, and no Layer state to manage, so it is more convenient, clearer, and friendlier to users.
  2. It supports mixed inputs (dist_tensors, ordinary tensors, and the non-tensor scalar arguments some functions require), making it more flexible and adaptable.
  3. It supports automatic reshard logic: input dist_tensors can be resharded in batch, so users need not call reshard manually multiple times.
  4. It supports automatic inference of the process_mesh.
  5. It adapts to both dynamic and static graph modes.
  6. It can wrap any Python function, not just a Layer's forward method.
  7. It supports mixed outputs (dist_tensors, ordinary tensors, and non-tensor values), allowing more flexible output handling.
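Points 2 and 7 (mixed inputs and outputs) can be sketched in plain Python: the wrapper only unwraps and re-wraps entries that look like dist tensors, passing plain values through untouched. All names here are illustrative, not Paddle's implementation; a dict stands in for a dist_tensor.

```python
def unwrap(arg):
    # Unwrap only "dist-tensor-like" arguments; pass everything else through.
    return arg["local"] if isinstance(arg, dict) and "local" in arg else arg

def wrap_outputs(outs, placements):
    # Re-wrap only outputs that were given a placement; None means "plain".
    return [
        {"local": o, "placement": p} if p is not None else o
        for o, p in zip(outs, placements)
    ]

def mixed_local_map(func, out_placements):
    def wrapper(*args):
        outs = func(*[unwrap(a) for a in args])
        return wrap_outputs(outs, out_placements)
    return wrapper

# A dist-like input (dict), a plain list, and a scalar all mix freely:
f = mixed_local_map(lambda x, y, k: ([xi + k for xi in x], sum(y)),
                    out_placements=["Shard(0)", None])
dist_out, plain_out = f({"local": [1, 2]}, [3, 4], 10)
```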

paddle-bot bot commented Mar 20, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 20, 2025
@@ -0,0 +1,280 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Contributor: 2024 -> 2025

Contributor Author: Done

Contributor: It still says 2024.

Contributor Author: Done

__all__ = ["local_map"]

PlacementType = Sequence[dist.Placement] | None
InputPlacements = tuple[PlacementType, ...] | None

Contributor: Why do the supported placement types differ between inputs and outputs?

Contributor Author: They have been unified.

def local_map(
    func: Callable,
    out_placements: OutputPlacements,
    in_placements: InputPlacements | None,

Contributor: Suggested change:

    in_placements: InputPlacements | None,
    in_placements: Optional[tuple[list[dist.Placement], ...]],

There is no need to create so many new type aliases; it is inconsistent with the conventions of other modules in the framework and only raises the cost of understanding for users.

Contributor Author (zty-king, Mar 23, 2025): Optional cannot be used here; pre-commit reports an error. The type alias names have been revised, though.


for idx, arg in enumerate(flat_args):
    if _is_distributed_tensor(arg):
        # TODO: the current code doesn't consider the uneven sharding case

Contributor: What does this comment mean?

Contributor Author (zty-king, Mar 23, 2025): Deleted; uneven sharding is not considered for now.

redistribute_inputs: bool | None,
):
    """
    :meth:`local_map` is an experimental API that allows users to pass dist_tensors

Contributor: We do not describe APIs as "experimental".

Contributor Author: Fixed.

if arg.placements != spec:
    if redistribute_inputs:
        # Redistribute to input placements
        arg = arg.redistribute(process_mesh, spec)

Contributor: Do we even have a redistribute interface?

Contributor Author: Removed and rewritten to follow the Paddle framework conventions.
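The behavior under discussion here (later renamed to reshard) amounts to a placement check with an optional conversion. A hedged sketch of that logic follows; `check_or_reshard` and `reshard_fn` are hypothetical names for illustration, not Paddle's API.

```python
def check_or_reshard(placement, expected, reshard_inputs, reshard_fn):
    """If an input's placement differs from the declared one, either
    convert it (when reshard_inputs is True) or fail loudly."""
    if placement == expected:
        return placement
    if reshard_inputs:
        return reshard_fn(placement, expected)  # convert to the expected placement
    raise ValueError(
        f"input placement {placement!r} does not match required {expected!r}; "
        "pass reshard_inputs=True to convert automatically"
    )

# A trivial reshard_fn stand-in that simply adopts the expected placement:
out = check_or_reshard("Shard(0)", "Replicate", True, lambda p, e: e)
```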

in_placements: InputPlacements | None,
process_mesh: ProcessMesh | None,
*,
redistribute_inputs: bool | None,

Contributor: Naming should follow the existing framework conventions; we do not use the term "redistribute".

Contributor Author: Done

else:
    return out

def _is_distributed_tensor(tensor) -> bool:

Contributor: Suggested change:

    def _is_distributed_tensor(tensor) -> bool:
    def is_dist_tensor(tensor) -> bool:

This is a very basic helper; it should live somewhere more public so other modules can reuse it.

Contributor Author: Done


return pack_sequence_as(out, flat_dist_out)
else:
    return out

Contributor: If the user's inputs contain no dist_tensor but distributed placements are specified for the outputs, is silently ignoring the output placements reasonable behavior?

Contributor Author: Handling for this case has been added.

if TYPE_CHECKING:
    from paddle.distributed import ProcessMesh

__all__ = ["local_map"]

Contributor: This is not exposed through the auto_parallel.local_map path, so it should not be added to this file's __all__.

Contributor Author: Done

@@ -19,6 +19,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_mlp MODULES test_mlp ENVS FLAGS_enable_pir_api=1)
py_test_modules(test_local_layer MODULES test_local_layer ENVS
FLAGS_enable_pir_api=1)
py_test_modules(test_local_map MODULES test_local_map ENVS
FLAGS_enable_pir_api=1)

Contributor: Why is FLAGS_enable_pir_api=1 still needed?

Contributor Author: It was originally added to mirror LocalLayer, but the PIR flag indeed seems unnecessary now; the unit tests pick it up automatically. Fixed.

From00 previously approved these changes Apr 7, 2025

@From00 From00 left a comment: LGTM

zhiqiu previously approved these changes Apr 8, 2025

@zhiqiu zhiqiu left a comment: LGTM

@jeff41404
Contributor

jeff41404 commented Apr 8, 2025

According to the newly added API specification of paddle, it is necessary to write the API Chinese documentation in docs repo for users to refer to the official website. please add link of docs repo PR in description above.

@zty-king
Contributor Author

zty-king commented Apr 8, 2025

According to the newly added API specification of paddle, it is necessary to write the API Chinese documentation in docs repo for users to refer to the official website. please add link of docs repo PR in description above.

Done
https://github.com/PaddlePaddle/docs/pull/7245/files?short_path=6326c88#diff-6326c883990bd5d2a43546a1c0e4f990a2b4d2dbb02aaa54a6013b42768ca348

Contributor

@jeff41404 jeff41404 left a comment: LGTM

Comment on lines 30 to 31
in_placements: list[list[dist.Placement]] | None,
process_mesh: ProcessMesh | None,

Contributor: Shouldn't the type annotation also get a default value of = None? @SigureMo please take a look.

Member: That depends on the interface shape; if there is no default value, then = None is not needed.

Contributor Author: > That depends on the interface shape; if there is no default value, then = None is not needed.

Could you check whether there are any other major issues? If not, could you approve first? I will submit another PR to fix the formatting.

Contributor Author: Done

Member: I never said this needed changing. Whether = None should be added depends on the interface shape; this was simply a comment @sunzhongkai588 left without understanding the code, and I was only explaining that point to him.

Please review this change: if the interface shape really does call for a default value, and that default should be None, then this change is fine; otherwise it is not needed.

Default: None

reshard_inputs (bool, optional):
    the bool value indicating whether to reshard the input :dist_tensor` s when

Contributor: Is :dist_tensor` a typo?

Contributor Author: Done

sunzhongkai588 previously approved these changes Apr 10, 2025

@sunzhongkai588 sunzhongkai588 left a comment: The documentation issues can be fixed in a follow-up PR; @SigureMo, don't forget to reply.

SigureMo previously approved these changes Apr 11, 2025

@SigureMo SigureMo left a comment: The Chinese documentation PR hasn't been written either? It can go in the next PR, since @sunzhongkai588 has agreed.



def local_map(
    func: Callable,

Member: Spell out the generic's internal parameter types.

Contributor Author: In practice the internal parameters accept variables of any type; does that still need to be written out?

Member: Yes: Callable[..., Any]

Contributor Author: Done

in_placements: list[list[dist.Placement]] | None,
process_mesh: ProcessMesh | None,
reshard_inputs: bool = False,
):

Member: Spell out the return type.

Contributor Author: Done

Member (SigureMo, Apr 15, 2025): I don't see it…

Contributor Author: Sorry, I misunderstood what you meant; it is added now.

in_placements.

Example:
    >>> from __future__ import annotations

Member: The format of this example code looks wrong; will it render correctly? Even if it renders in the English docs, the Chinese docs will not be able to COPY-FROM it.

Contributor Author: Done

Comment on lines 75 to 85
Raises:
    AssertionError: If the number of output placements does not match the number
        of function outputs.

    AssertionError: If a non-tensor output has a non-None placement specified.

    AssertionError: If process_mesh is None and there are no dist_tensor inputs
        but out_placements contains non-None values.

    ValueError: If the input dist_tensor placements don't match the required
        in_placements.

Member: Per the documentation guidelines, do not write a Raises section.

Contributor Author: Done

zty-king commented Apr 11, 2025 via email

zty-king commented Apr 11, 2025 via email


paddle-ci-bot bot commented Apr 14, 2025

Sorry to inform you that fdbfcf3's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@zty-king zty-king dismissed stale reviews from zhiqiu, From00, SigureMo, and sunzhongkai588 via 798c1a6 April 14, 2025 16:18
@zty-king
Contributor Author

> The Chinese documentation PR hasn't been written either? It can go in the next PR, since @sunzhongkai588 has agreed.

PaddlePaddle/docs#7245 is the corresponding Chinese documentation PR.

@zty-king
Contributor Author

zty-king commented Apr 15, 2025 via email

Member

@SigureMo SigureMo left a comment: LGTMeow 🐾 for type annotations

@zty-king
Contributor Author

@jeff41404 I adjusted the format as requested; please review it again.

Contributor

@sunzhongkai588 sunzhongkai588 left a comment: LGTM

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment: LGTM

@From00 From00 merged commit 483c9e1 into PaddlePaddle:develop Apr 24, 2025
36 of 37 checks passed
YqGe585 pushed a commit to YqGe585/Paddle that referenced this pull request May 7, 2025
* Add API local_map

* Fix file formatting

* Optimize some local_map functionality

* Add reshard support and make local_map callable in both dynamic and static graph modes

* Fix formatting conventions

* Fix unit test interface naming

* Replace LocalLayer with local_map

* Adjust unit test parameters when using local_map; set reshard to True

* Fix unit tests

* Modify unit tests

* Fix formatting

* Fix formatting

* Fix test example formatting

* Fix test example formatting
Labels
contributor External developers
Projects: None yet

7 participants