Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 13 additions & 24 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,20 @@ __all__ = [
##### Avoid Re-initialization

For heavy workflows, re-initializing every time can incurs extra computational costs.
In this case, you can implement the `resettable` property and `reset` method to avoid re-initialization.
In this case, you can set the `can_reset` property and implement `reset` method to avoid re-initialization.

The `can_reset` is a class property that indicates whether the workflow supports resetting.

The `resettable` property returns a boolean indicating whether the workflow supports resetting.
The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True

# some code
# ...

def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
Expand All @@ -227,20 +227,18 @@ class ExampleWorkflow(Workflow):
##### Support Batch Inference

In many popular RL algorithms, multiple runs of the same task are required (e.g., GRPO). In such scenarios, you can directly use batch inference to obtain multiple responses for a single question to improve efficiency.
For this case, you can implement the `repeatable` property and `set_repeat_times` method.
For this case, you can implement the `can_repeat` property and `set_repeat_times` method.

The `can_repeat` is a class property that indicates whether the workflow supports multiple executions within the `run` method.

The `repeatable` property returns a boolean indicating whether the workflow supports multiple executions within the `run` method.
The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_repeat: bool = True
# some code

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -279,6 +277,8 @@ class ExampleWorkflow(Workflow):
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True
can_repeat: bool = True

def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
Expand Down Expand Up @@ -319,18 +319,10 @@ class ExampleWorkflow(Workflow):
)
return experiences

@property
def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -364,15 +356,12 @@ trinity run --config <your_yaml_file>

#### Async Support

The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can implement the `asynchronous` property and set it to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected.
The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, while other methods and properties remain unaffected.

```python
@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

@property
def asynchronous(self):
return True
is_async: bool = True

async def run_async(self) -> List[Experience]:
# your async code here
Expand Down
35 changes: 11 additions & 24 deletions docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,22 +200,20 @@ __all__ = [
##### 避免重复初始化

对于较为复杂的工作流,每次重新初始化会带来额外计算开销。
此时,你可以实现 `resettable` 和 `reset` 方法以避免重复初始化。
此时,你可以设置 `can_reset` 属性并实现 `reset` 方法以避免重复初始化。

`resettable` 方法返回一个布尔值,指示工作流是否支持轻量化重置
`can_reset` 是一个类属性,表示工作流是否支持轻量化重置

`reset` 方法接受一个新的 `Task` 实例,并使用该实例更新工作流的状态。

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True

# some code
# ...

@property
def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
Expand All @@ -224,21 +222,18 @@ class ExampleWorkflow(Workflow):
##### 批量运行推理任务

当前流行的很多 RL 算法需要多次运行同一个任务(例如 GRPO)。该场景下一些简单任务可以直接通过模型批量推理来获得一个问题的多个回复以提升效率。
针对该情况,你可以实现 `repeatable` 属性以及 `set_repeat_times` 方法。
针对该情况,你可以设置 `can_repeat` 属性并实现 `set_repeat_times` 方法。

`repeatable` 属性返回一个布尔值,指示工作流是否支持在 `run` 方法内多次执行。
`can_repeat` 是一个类属性,指示工作流是否支持在 `run` 方法内多次执行。

`set_repeat_times` 方法接受两个参数:`repeat_times` 指定了在 `run` 方法内需要执行的次数,`run_id_base` 是一个整数,用于标识多次运行中第一次的运行 ID,之后各次的 ID 基于此递增(该参数用于多轮交互场景,单次模型调用即可完成的任务可以忽略该项)。

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_repeat: bool = True
# some code

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -277,6 +272,8 @@ class ExampleWorkflow(Workflow):
```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True
can_repeat: bool = True

def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
Expand Down Expand Up @@ -317,18 +314,10 @@ class ExampleWorkflow(Workflow):
)
return experiences

@property
def resettable(self):
return True

def reset(self, task: Task):
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")

@property
def repeatable(self) -> bool:
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -362,15 +351,13 @@ trinity run --config <your_yaml_file>

#### async 支持

本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以实现 `asynchronous` 属性并将其设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响。
本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,其余方法和属性不受影响。

```python
@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

@property
def asynchronous(self):
return True
is_async: bool = True

async def run_async(self) -> List[Experience]:
# your async code here
Expand Down
27 changes: 7 additions & 20 deletions tests/explorer/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

@WORKFLOWS.register_module("dummy_workflow")
class DummyWorkflow(Workflow):
can_repeat: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.step_num = task.workflow_args.get("step_num", 1)
Expand All @@ -30,10 +32,6 @@ def __init__(self, *, task, model, auxiliary_models):
else:
self.seconds = 10

@property
def repeatable(self):
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down Expand Up @@ -63,19 +61,13 @@ def run(self) -> List[Experience]:

@WORKFLOWS.register_module("dummy_nonrepeat_workflow")
class DummyNonRepeatWorkflow(Workflow):
can_reset: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.reset_flag = False
self.step_num = task.workflow_args.get("step_num", 1)

@property
def resettable(self):
return True

@property
def repeatable(self):
return False

def reset(self, task: Task):
self.task = task
self.reset_flag = True
Expand All @@ -95,18 +87,13 @@ def run(self) -> List[Experience]:

@WORKFLOWS.register_module("dummy_async_workflow")
class DummyAsyncWorkflow(Workflow):
can_repeat: bool = True
is_async: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.step_num = task.workflow_args.get("step_num", 1)

@property
def asynchronous(self):
return True

@property
def repeatable(self):
return True

def set_repeat_times(self, repeat_times, run_id_base):
self.repeat_times = repeat_times
self.run_id_base = run_id_base
Expand Down
Loading
Loading