Skip to content

Commit 3768469

Browse files
authored
Fixes template generator (#2161)
# Description Fix template generator: - Fix `rsl_rl` agent configuration - Don't list `skrl`'s multi-agent algorithms for single-agent tasks - Don't list `rsl_rl` and `sb3` for multi-agent tasks - Update docs to include usage steps for the generated internal task ## Type of change <!-- As you go through the list, delete the ones that are not applicable. --> - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there <!-- As you go through the checklist above, you can mark something as done by putting an x character in it For example, - [x] I have done this task - [ ] I have not done this task -->
1 parent da5618b commit 3768469

File tree

6 files changed

+139
-28
lines changed

6 files changed

+139
-28
lines changed

docs/source/overview/developer-guide/template.rst

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,22 @@ Here are some general commands to get started with it:
6868

6969
* Install the project (in editable mode).
7070

71-
.. code:: bash
71+
.. tab-set::
72+
:sync-group: os
7273

73-
python -m pip install -e source/<given-project-name>
74+
.. tab-item:: :icon:`fa-brands fa-linux` Linux
75+
:sync: linux
76+
77+
.. code-block:: bash
78+
79+
python -m pip install -e source/<given-project-name>
80+
81+
.. tab-item:: :icon:`fa-brands fa-windows` Windows
82+
:sync: windows
83+
84+
.. code-block:: batch
85+
86+
python -m pip install -e source\<given-project-name>
7487
7588
* List the tasks available in the project.
7689

@@ -79,14 +92,90 @@ Here are some general commands to get started with it:
7992
If the task names change, it may be necessary to update the search pattern ``"Template-"``
8093
(in the ``scripts/list_envs.py`` file) so that they can be listed.
8194

82-
.. code:: bash
95+
.. tab-set::
96+
:sync-group: os
97+
98+
.. tab-item:: :icon:`fa-brands fa-linux` Linux
99+
:sync: linux
100+
101+
.. code-block:: bash
102+
103+
python scripts/list_envs.py
104+
105+
.. tab-item:: :icon:`fa-brands fa-windows` Windows
106+
:sync: windows
107+
108+
.. code-block:: batch
83109
84-
python scripts/list_envs.py
110+
python scripts\list_envs.py
85111
86112
* Run a task.
87113

88-
.. code:: bash
114+
.. tab-set::
115+
:sync-group: os
89116

90-
python scripts/<specific-rl-library>/train.py --task=<Task-Name>
117+
.. tab-item:: :icon:`fa-brands fa-linux` Linux
118+
:sync: linux
119+
120+
.. code-block:: bash
121+
122+
python scripts/<specific-rl-library>/train.py --task=<Task-Name>
123+
124+
.. tab-item:: :icon:`fa-brands fa-windows` Windows
125+
:sync: windows
126+
127+
.. code-block:: batch
128+
129+
python scripts\<specific-rl-library>\train.py --task=<Task-Name>
91130
92131
For more details, please follow the instructions in the generated project's ``README.md`` file.
132+
133+
Internal task usage (once generated)
134+
---------------------------------------
135+
136+
Once the internal task is generated, it will be available along with the rest of the Isaac Lab tasks.
137+
138+
Here are some general commands to get started with it:
139+
140+
.. note::
141+
142+
If Isaac Lab is not installed in a conda environment or in a (virtual) Python environment, use ``./isaaclab.sh -p``
143+
(or ``isaaclab.bat -p`` on Windows) instead of ``python`` to run the commands below.
144+
145+
* List the tasks available in Isaac Lab.
146+
147+
.. tab-set::
148+
:sync-group: os
149+
150+
.. tab-item:: :icon:`fa-brands fa-linux` Linux
151+
:sync: linux
152+
153+
.. code-block:: bash
154+
155+
python scripts/environments/list_envs.py
156+
157+
.. tab-item:: :icon:`fa-brands fa-windows` Windows
158+
:sync: windows
159+
160+
.. code-block:: batch
161+
162+
python scripts\environments\list_envs.py
163+
164+
* Run a task.
165+
166+
.. tab-set::
167+
:sync-group: os
168+
169+
.. tab-item:: :icon:`fa-brands fa-linux` Linux
170+
:sync: linux
171+
172+
.. code-block:: bash
173+
174+
python scripts/reinforcement_learning/<specific-rl-library>/train.py --task=<Task-Name>
175+
176+
.. tab-item:: :icon:`fa-brands fa-windows` Windows
177+
:sync: windows
178+
179+
.. code-block:: batch
180+
181+
python scripts\reinforcement_learning\<specific-rl-library>\train.py --task=<Task-Name>

tools/template/cli.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55

66
import enum
7-
import glob
87
import os
98
from collections.abc import Callable
109

1110
import rich.console
1211
import rich.table
13-
from common import ROOT_DIR, TEMPLATE_DIR
14-
from generator import generate
12+
from common import ROOT_DIR
13+
from generator import generate, get_algorithms_per_rl_library
1514
from InquirerPy import inquirer, separator
1615

1716

@@ -144,16 +143,6 @@ class State(str, enum.Enum):
144143
No = "[red]no[/red]"
145144

146145

147-
def _get_algorithms_per_rl_library():
148-
data = {"rl_games": [], "rsl_rl": [], "skrl": [], "sb3": []}
149-
for file in glob.glob(os.path.join(TEMPLATE_DIR, "agents", "*_cfg")):
150-
for rl_library in data.keys():
151-
basename = os.path.basename(file).replace("_cfg", "")
152-
if basename.startswith(f"{rl_library}_"):
153-
data[rl_library].append(basename.replace(f"{rl_library}_", "").upper())
154-
return data
155-
156-
157146
def main() -> None:
158147
"""Main function to run template generation from CLI."""
159148
cli_handler = CLIHandler()
@@ -207,10 +196,12 @@ def main() -> None:
207196
default=supported_workflows,
208197
)
209198
workflow = [{"name": item.split(" | ")[0].lower(), "type": item.split(" | ")[1].lower()} for item in workflow]
199+
single_agent_workflow = [item for item in workflow if item["type"] == "single-agent"]
200+
multi_agent_workflow = [item for item in workflow if item["type"] == "multi-agent"]
210201

211202
# RL library
212203
rl_library_algorithms = []
213-
algorithms_per_rl_library = _get_algorithms_per_rl_library()
204+
algorithms_per_rl_library = get_algorithms_per_rl_library()
214205
# - show supported RL libraries and features
215206
rl_library_table = rich.table.Table(title="Supported RL libraries")
216207
rl_library_table.add_column("RL/training feature", no_wrap=True)
@@ -219,25 +210,29 @@ def main() -> None:
219210
rl_library_table.add_column("skrl")
220211
rl_library_table.add_column("sb3")
221212
rl_library_table.add_row("ML frameworks", "PyTorch", "PyTorch", "PyTorch, JAX", "PyTorch")
213+
rl_library_table.add_row("Relative performance", "~1X", "~1X", "~1X", "~0.03X")
222214
rl_library_table.add_row(
223215
"Algorithms",
224216
", ".join(algorithms_per_rl_library.get("rl_games", [])),
225217
", ".join(algorithms_per_rl_library.get("rsl_rl", [])),
226218
", ".join(algorithms_per_rl_library.get("skrl", [])),
227219
", ".join(algorithms_per_rl_library.get("sb3", [])),
228220
)
229-
rl_library_table.add_row("Relative performance", "~1X", "~1X", "~1X", "~0.03X")
221+
rl_library_table.add_row("Multi-agent support", State.Yes, State.No, State.Yes, State.No)
230222
rl_library_table.add_row("Distributed training", State.Yes, State.No, State.Yes, State.No)
231223
rl_library_table.add_row("Vectorized training", State.Yes, State.Yes, State.Yes, State.No)
232224
rl_library_table.add_row("Fundamental/composite spaces", State.No, State.No, State.Yes, State.No)
233225
cli_handler.output_table(rl_library_table)
234226
# - prompt for RL libraries
235-
supported_rl_libraries = ["rl_games", "rsl_rl", "skrl", "sb3"]
227+
supported_rl_libraries = (
228+
["rl_games", "rsl_rl", "skrl", "sb3"] if len(single_agent_workflow) else ["rl_games", "skrl"]
229+
)
236230
selected_rl_libraries = cli_handler.get_choices(
237231
cli_handler.input_checkbox("RL library:", choices=[*supported_rl_libraries, "---", "all"]),
238232
default=supported_rl_libraries,
239233
)
240234
# - prompt for algorithms per RL library
235+
algorithms_per_rl_library = get_algorithms_per_rl_library(len(single_agent_workflow), len(multi_agent_workflow))
241236
for rl_library in selected_rl_libraries:
242237
algorithms = algorithms_per_rl_library.get(rl_library, [])
243238
if len(algorithms) > 1:

tools/template/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55

66
import os
77

8+
# paths
89
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
910
TASKS_DIR = os.path.join(ROOT_DIR, "source", "isaaclab_tasks", "isaaclab_tasks")
1011
TEMPLATE_DIR = os.path.join(ROOT_DIR, "tools", "template", "templates")
12+
13+
# RL algorithms
14+
SINGLE_AGENT_ALGORITHMS = ["AMP", "PPO"]
15+
MULTI_AGENT_ALGORITHMS = ["IPPO", "MAPPO"]

tools/template/generator.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from datetime import datetime
1212

1313
import jinja2
14-
from common import ROOT_DIR, TASKS_DIR, TEMPLATE_DIR
14+
from common import MULTI_AGENT_ALGORITHMS, ROOT_DIR, SINGLE_AGENT_ALGORITHMS, TASKS_DIR, TEMPLATE_DIR
1515

1616
jinja_env = jinja2.Environment(
1717
loader=jinja2.FileSystemLoader(TEMPLATE_DIR),
@@ -260,6 +260,28 @@ def _external(specification: dict) -> None:
260260
print("-" * 80)
261261

262262

263+
def get_algorithms_per_rl_library(single_agent: bool = True, multi_agent: bool = True):
264+
assert single_agent or multi_agent, "At least one of 'single_agent' or 'multi_agent' must be True"
265+
data = {"rl_games": [], "rsl_rl": [], "skrl": [], "sb3": []}
266+
# get algorithms
267+
for file in glob.glob(os.path.join(TEMPLATE_DIR, "agents", "*_cfg")):
268+
for rl_library in data.keys():
269+
basename = os.path.basename(file).replace("_cfg", "")
270+
if basename.startswith(f"{rl_library}_"):
271+
algorithm = basename.replace(f"{rl_library}_", "").upper()
272+
assert (
273+
algorithm in SINGLE_AGENT_ALGORITHMS or algorithm in MULTI_AGENT_ALGORITHMS
274+
), f"{algorithm} algorithm is not listed in the supported algorithms"
275+
if single_agent and algorithm in SINGLE_AGENT_ALGORITHMS:
276+
data[rl_library].append(algorithm)
277+
if multi_agent and algorithm in MULTI_AGENT_ALGORITHMS:
278+
data[rl_library].append(algorithm)
279+
# remove duplicates and sort
280+
for rl_library in data.keys():
281+
data[rl_library] = sorted(list(set(data[rl_library])))
282+
return data
283+
284+
263285
def generate(specification: dict) -> None:
264286
"""Generate the project/task.
265287

tools/template/templates/agents/rsl_rl_ppo_cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, R
99

1010

1111
@configclass
12-
class CartpolePPORunnerCfg(RslRlOnPolicyRunnerCfg):
12+
class PPORunnerCfg(RslRlOnPolicyRunnerCfg):
1313
num_steps_per_env = 16
1414
max_iterations = 150
1515
save_interval = 50

tools/template/templates/tasks/__init__task

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ gym.register(
2727
{% for algorithm in rl_library.algorithms %}
2828
{# configuration file #}
2929
{% if rl_library.name == "rsl_rl" %}
30-
{% set agent_config = rl_library.name ~ "_" ~ algorithm ~ "_cfg:" ~ algorithm|upper ~ "RunnerCfg" %}
30+
{% set agent_config = "." ~ rl_library.name ~ "_" ~ algorithm ~ "_cfg:" ~ algorithm|upper ~ "RunnerCfg" %}
3131
{% else %}
32-
{% set agent_config = rl_library.name ~ "_" ~ algorithm ~ "_cfg.yaml" %}
32+
{% set agent_config = ":" ~ rl_library.name ~ "_" ~ algorithm ~ "_cfg.yaml" %}
3333
{% endif %}
3434
{# library configuration #}
3535
{% if algorithm == "ppo" %}
36-
"{{ rl_library.name }}_cfg_entry_point": f"{agents.__name__}:{{ agent_config }}",
36+
"{{ rl_library.name }}_cfg_entry_point": f"{agents.__name__}{{ agent_config }}",
3737
{% else %}
38-
"{{ rl_library.name }}_{{ algorithm }}_cfg_entry_point": f"{agents.__name__}:{{ agent_config }}",
38+
"{{ rl_library.name }}_{{ algorithm }}_cfg_entry_point": f"{agents.__name__}{{ agent_config }}",
3939
{% endif %}
4040
{% endfor %}
4141
{% endfor %}

0 commit comments

Comments
 (0)