Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 14 additions & 24 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,20 @@ __all__ = [
##### Avoid Re-initialization

For heavy workflows, re-initializing every time can incurs extra computational costs.
In this case, you can implement the `resettable` property and `reset` method to avoid re-initialization.
In this case, you can set the `can_reset` property and implement `reset` method to avoid re-initialization.

The `can_reset` is a class property that indicates whether the workflow supports resetting.

The `resettable` property returns a boolean indicating whether the workflow supports resetting.
The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True

# some code
# ...

def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
Expand All @@ -227,20 +227,18 @@ class ExampleWorkflow(Workflow):
##### Support Batch Inference

In many popular RL algorithms, multiple runs of the same task are required (e.g., GRPO). In such scenarios, you can directly use batch inference to obtain multiple responses for a single question to improve efficiency.
For this case, you can implement the `repeatable` property and `set_repeat_times` method.
For this case, you can implement the `can_repeat` property and `set_repeat_times` method.

The `can_repeat` is a class property that indicates whether the workflow supports multiple executions within the `run` method.

The `repeatable` property returns a boolean indicating whether the workflow supports multiple executions within the `run` method.
The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_repeat: bool = True
# some code

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -279,6 +277,8 @@ class ExampleWorkflow(Workflow):
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True
can_repeat: bool = True

def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
Expand Down Expand Up @@ -319,18 +319,10 @@ class ExampleWorkflow(Workflow):
)
return experiences

@property
def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -364,15 +356,13 @@ trinity run --config <your_yaml_file>

#### Async Support

The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can implement the `asynchronous` property and set it to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected.
The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain changed.

```python
@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

@property
def asynchronous(self):
return True
is_async: bool = True

async def run_async(self) -> List[Experience]:
# your async code here
Expand Down Expand Up @@ -458,7 +448,7 @@ explorer:

Note that each auxiliary model will independently occupy `tensor_parallel_size * engine_num` GPUs. Please configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and the inference model being trained (`rollout_model`).

The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` instances to the `auxiliary_models` parameter of the `Workflow` initialization method. For example:
The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` or `openai.AsyncOpenAI` instances (depending on the `is_async` setting) to the `auxiliary_models` parameter of the `Workflow` initialization method. For example:

```python
class MyWorkflow(Workflow):
Expand Down
37 changes: 12 additions & 25 deletions docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,20 @@ __all__ = [
##### 避免重复初始化

对于较为复杂的工作流,每次重新初始化会带来额外计算开销。
此时,你可以实现 `resettable` `reset` 方法以避免重复初始化。
此时,你可以设置 `can_reset` 属性并实现 `reset` 方法以避免重复初始化。

`resettable` 方法返回一个布尔值,指示工作流是否支持轻量化重置
`can_reset` 是一个类属性,表示工作流是否支持轻量化重置

`reset` 方法接受一个新的 `Task` 实例,并使用该实例更新工作流的状态。

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True

# some code
# ...

@property
def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
Expand All @@ -224,21 +222,18 @@ class ExampleWorkflow(Workflow):
##### 批量运行推理任务

当前流行的很多 RL 算法需要多次运行同一个任务(例如 GRPO)。该场景下一些简单任务可以直接通过模型批量推理来获得一个问题的多个回复以提升效率。
针对该情况,你可以实现 `repeatable` 属性以及 `set_repeat_times` 方法。
针对该情况,你可以设置 `can_repeat` 属性并实现 `set_repeat_times` 方法。

`repeatable` 属性返回一个布尔值,指示工作流是否支持在 `run` 方法内多次执行。
`can_repeat` 是一个类属性,指示工作流是否支持在 `run` 方法内多次执行。

`set_repeat_times` 方法接受两个参数:`repeat_times` 指定了在 `run` 方法内需要执行的次数,`run_id_base` 是一个整数,用于标识多次运行中第一次的运行 ID,之后各次的 ID 基于此递增(该参数用于多轮交互场景,单次模型调用即可完成的任务可以忽略该项)。

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_repeat: bool = True
# some code

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -277,6 +272,8 @@ class ExampleWorkflow(Workflow):
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True
can_repeat: bool = True

def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
Expand Down Expand Up @@ -317,18 +314,10 @@ class ExampleWorkflow(Workflow):
)
return experiences

@property
def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -362,15 +351,13 @@ trinity run --config <your_yaml_file>

#### async 支持

本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以实现 `asynchronous` 属性并将其设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响
本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,并且初始化参数 `auxiliary_models` 也会自动变为 `List[openai.AsyncOpenAI]` 类型,其余方法和属性保持不变

```python
@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

@property
def asynchronous(self):
return True
is_async: bool = True

async def run_async(self) -> List[Experience]:
# your async code here
Expand Down Expand Up @@ -458,7 +445,7 @@ explorer:
请注意,每个辅助模型会独立占用 `tensor_parallel_size * engine_num` 个 GPU,请根据硬件资源合理配置。在启用辅助模型后,Trainer 可用的 GPU 数量为总 GPU 数量减去所有辅助模型及被训练的推理模型(`rollout_model`)所占用的 GPU 数量。

配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 实例传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如:
配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 或 `openai.AsyncOpenAI` 实例 (取决于 `is_async`) 传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如:

```python
class MyWorkflow(Workflow):
Expand Down
27 changes: 7 additions & 20 deletions tests/explorer/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

@WORKFLOWS.register_module("dummy_workflow")
class DummyWorkflow(Workflow):
can_repeat: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.step_num = task.workflow_args.get("step_num", 1)
Expand All @@ -30,10 +32,6 @@ def __init__(self, *, task, model, auxiliary_models):
else:
self.seconds = 10

@property
def repeatable(self):
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -63,19 +61,13 @@ def run(self) -> List[Experience]:

@WORKFLOWS.register_module("dummy_nonrepeat_workflow")
class DummyNonRepeatWorkflow(Workflow):
can_reset: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.reset_flag = False
self.step_num = task.workflow_args.get("step_num", 1)

@property
def resettable(self):
return True

@property
def repeatable(self):
return False

def reset(self, task: Task):
self.task = task
self.reset_flag = True
Expand All @@ -95,18 +87,13 @@ def run(self) -> List[Experience]:

@WORKFLOWS.register_module("dummy_async_workflow")
class DummyAsyncWorkflow(Workflow):
can_repeat: bool = True
is_async: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.step_num = task.workflow_args.get("step_num", 1)

@property
def asynchronous(self):
return True

@property
def repeatable(self):
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down
Loading