Skip to content

Commit e2e4c32

Browse files
[RLlib] Adjust callback validation to account for MultiCallback. (#50920)
1 parent fde33f9 commit e2e4c32

File tree

4 files changed

+164
-3
lines changed

4 files changed

+164
-3
lines changed

rllib/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,19 @@ py_test(
10701070
srcs = ["algorithms/sac/tests/test_sac.py"]
10711071
)
10721072

1073+
# --------------------------------------------------------------------
1074+
# Callback tests
1075+
# rllib/callbacks/
1076+
#
1077+
# Tag: callbacks_dir
1078+
# --------------------------------------------------------------------
1079+
# Test target for the callback tests in callbacks/tests/test_multicallback.py.
py_test(
    name = "test_multicallback",
    size = "medium",
    srcs = ["callbacks/tests/test_multicallback.py"],
    tags = [
        "team:rllib",
        "callbacks_dir",
    ],
)
1085+
10731086
# --------------------------------------------------------------------
10741087
# ConnectorV2 tests
10751088
# rllib/connector/

rllib/algorithms/algorithm_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ray.rllib.offline.io_context import IOContext
4040
from ray.rllib.policy.policy import Policy, PolicySpec
4141
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
42-
from ray.rllib.utils import deep_update, merge_dicts
42+
from ray.rllib.utils import deep_update, force_list, merge_dicts
4343
from ray.rllib.utils.annotations import (
4444
OldAPIStack,
4545
OverrideToImplementCustomLogic_CallToSuperRecommended,
@@ -2518,9 +2518,10 @@ def callbacks(
25182518
# Check, whether given `callbacks` is a callable.
25192519
# TODO (sven): Once the old API stack is deprecated, this can also be None
25202520
# (which should then become the default value for this attribute).
2521-
if not callable(callbacks_class):
2521+
to_check = force_list(callbacks_class)
2522+
if not all(callable(c) for c in to_check):
25222523
raise ValueError(
2523-
"`config.callbacks_class` must be a callable method that "
2524+
"`config.callbacks_class` must be a callable or list of callables that "
25242525
"returns a subclass of DefaultCallbacks, got "
25252526
f"{callbacks_class}!"
25262527
)

rllib/callbacks/tests/__init__.py

Whitespace-only changes.
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import unittest
2+
import ray
3+
from ray.rllib.algorithms import PPOConfig
4+
from ray.rllib.callbacks.callbacks import RLlibCallback
5+
6+
7+
class TestMultiCallback(unittest.TestCase):
    """Test suite for configuring several callbacks at once (`MultiCallback`).

    Covers two things:
      * callbacks passed as a list of `RLlibCallback` classes, together with
        a plain callback function, all run and log their metrics;
      * `AlgorithmConfig.callbacks()` rejects non-callable `callbacks_class`
        values with a `ValueError`.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Class-level fixture: Ray is started once for the whole suite
        # instead of being re-initialized before every single test.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        # Counterpart of `setUpClass`: shut Ray down after the last test.
        ray.shutdown()

    def test_multicallback_with_custom_callback_function(self):
        """Tests if callbacks in `MultiCallback` get executed.

        This also tests whether multiple callbacks from different sources,
        i.e. `callbacks_class` and `on_episode_step`, run correctly together.
        """
        # Define two standard `RLlibCallback`s. Each one logs a distinct
        # constant under a distinct key so their effects can be told apart.
        class TestRLlibCallback1(RLlibCallback):
            def on_episode_step(
                self,
                *,
                episode,
                env_runner=None,
                metrics_logger=None,
                env=None,
                env_index,
                rl_module=None,
                worker=None,
                base_env=None,
                policies=None,
                **kwargs
            ):
                metrics_logger.log_value(
                    "callback_1", 1, reduce="mean", clear_on_reduce=True
                )

        class TestRLlibCallback2(RLlibCallback):
            def on_episode_step(
                self,
                *,
                episode,
                env_runner=None,
                metrics_logger=None,
                env=None,
                env_index,
                rl_module=None,
                worker=None,
                base_env=None,
                policies=None,
                **kwargs
            ):
                metrics_logger.log_value(
                    "callback_2", 2, reduce="mean", clear_on_reduce=True
                )

        # Define a custom callback function (not a class) for the same event.
        def custom_on_episode_step_callback(
            episode,
            env_runner=None,
            metrics_logger=None,
            env=None,
            env_index=None,
            rl_module=None,
            worker=None,
            base_env=None,
            policies=None,
            **kwargs
        ):
            metrics_logger.log_value(
                "custom_callback", 3, reduce="mean", clear_on_reduce=True
            )

        # Configure the algorithm.
        config = (
            PPOConfig()
            .environment("CartPole-v1")
            .api_stack(
                enable_env_runner_and_connector_v2=True,
                enable_rl_module_and_learner=True,
            )
            # Use the callbacks and callback function.
            .callbacks(
                callbacks_class=[TestRLlibCallback1, TestRLlibCallback2],
                on_episode_step=custom_on_episode_step_callback,
            )
        )

        # Build the algorithm. At this stage, callbacks get already validated.
        algo = config.build()

        # Run 10 training iterations and check that the metrics defined in
        # the callbacks made it into the results and carry the logged values.
        # The `finally` guarantees the algorithm is stopped even when an
        # assertion fails midway through the loop.
        try:
            for _ in range(10):
                results = algo.train()
                env_runner_results = results["env_runners"]
                self.assertIn("callback_1", env_runner_results)
                self.assertIn("callback_2", env_runner_results)
                self.assertIn("custom_callback", env_runner_results)
                self.assertAlmostEqual(env_runner_results["callback_1"], 1)
                self.assertAlmostEqual(env_runner_results["callback_2"], 2)
                self.assertAlmostEqual(env_runner_results["custom_callback"], 3)
        finally:
            algo.stop()

    def test_multicallback_validation_error(self):
        """Checks that the validation safeguard catches wrong `MultiCallback`s."""
        with self.assertRaises(ValueError):
            (
                PPOConfig()
                .environment("CartPole-v1")
                .api_stack(
                    enable_env_runner_and_connector_v2=True,
                    enable_rl_module_and_learner=True,
                )
                # This is wrong b/c it needs callables, not strings.
                .callbacks(callbacks_class=["TestRLlibCallback1", "TestRLlibCallback2"])
            )

    def test_single_callback_validation_error(self):
        """Tests if the validation safeguard catches wrong `RLlibCallback`s."""
        with self.assertRaises(ValueError):
            (
                PPOConfig()
                .environment("CartPole-v1")
                .api_stack(
                    enable_env_runner_and_connector_v2=True,
                    enable_rl_module_and_learner=True,
                )
                # This is wrong b/c it needs a callable, not a string.
                .callbacks(callbacks_class="TestRLlibCallback")
            )
141+
142+
143+
if __name__ == "__main__":
    import sys

    import pytest

    # Run only this file's tests in verbose mode and hand pytest's exit
    # status back to the calling process.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)

0 commit comments

Comments
 (0)