
Commit 878d023

[Feature] Implicit Q-Learning (IQL) (#933)
1 parent 45c6129 commit 878d023
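
For context, Implicit Q-Learning (IQL, Kostrikov et al. 2021) is an offline RL algorithm that avoids evaluating the Q-function on out-of-distribution actions: a state-value network is fitted to the Q-estimates by expectile regression, the Q-networks are trained by TD against that value network, and the policy is extracted with advantage-weighted regression. As a quick recap (a sketch of the standard formulation; tau and beta correspond to the expectile and temperature arguments of the new IQLLoss, and the three terms are what the example below reads back as loss_value, loss_qvalue and loss_actor):

L_V(\psi) = \mathbb{E}_{(s,a)\sim\mathcal{D}}\left[ L_2^{\tau}\!\left(Q_{\hat\theta}(s,a) - V_\psi(s)\right) \right], \qquad L_2^{\tau}(u) = \left|\tau - \mathbb{1}(u < 0)\right| u^2

L_Q(\theta) = \mathbb{E}_{(s,a,r,s')\sim\mathcal{D}}\left[ \left(r + \gamma V_\psi(s') - Q_\theta(s,a)\right)^2 \right]

L_\pi(\phi) = -\,\mathbb{E}_{(s,a)\sim\mathcal{D}}\left[ \exp\!\left(\beta \left(Q_{\hat\theta}(s,a) - V_\psi(s)\right)\right) \log \pi_\phi(a \mid s) \right]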

9 files changed: +998 additions, -7 deletions

.circleci/unittest/linux_examples/scripts/run_test.sh

Lines changed: 35 additions & 0 deletions
@@ -111,6 +111,23 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/drea
   record_frames=4 \
   buffer_size=120 \
   rssm_hidden_dim=17
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
+  total_frames=48 \
+  init_random_frames=10 \
+  batch_size=10 \
+  frames_per_batch=16 \
+  num_workers=4 \
+  env_per_collector=2 \
+  collector_devices=cuda:0 \
+  mode=offline
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
+  total_frames=48 \
+  batch_size=10 \
+  frames_per_batch=16 \
+  num_workers=4 \
+  env_per_collector=2 \
+  collector_devices=cuda:0 \
+  mode=offline
 
 # With single envs
 python .circleci/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
@@ -196,6 +213,24 @@ python .circleci/unittest/helpers/coverage_run_parallel.py examples/dreamer/drea
   record_frames=4 \
   buffer_size=120 \
   rssm_hidden_dim=17
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/td3/td3.py \
+  total_frames=48 \
+  init_random_frames=10 \
+  batch_size=10 \
+  frames_per_batch=16 \
+  num_workers=2 \
+  env_per_collector=1 \
+  mode=offline \
+  collector_devices=cuda:0
+python .circleci/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
+  total_frames=48 \
+  batch_size=10 \
+  frames_per_batch=16 \
+  num_workers=2 \
+  env_per_collector=1 \
+  mode=offline \
+  collector_devices=cuda:0
+
 
 coverage combine
 coverage xml -i

docs/source/reference/objectives.rst

Lines changed: 9 additions & 0 deletions
@@ -40,6 +40,15 @@ REDQ
 
     REDQLoss
 
+IQL
+----
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    IQLLoss
+
 PPO
 ---
 
examples/iql/iql_online.py

Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer

from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import TanhNormal

from torchrl.objectives import SoftUpdate
from torchrl.objectives.iql import IQLLoss
from torchrl.record.loggers import generate_exp_name, get_logger


def env_maker(env_name, frame_skip=1, device="cpu", from_pixels=False):
    return GymEnv(
        env_name, "run", device=device, frame_skip=frame_skip, from_pixels=from_pixels
    )


def make_replay_buffer(
    prb=False,
    buffer_size=1000000,
    buffer_scratch_dir="/tmp/",
    device="cpu",
    make_replay_buffer=3,
):
    if prb:
        replay_buffer = TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.5,
            pin_memory=False,
            prefetch=make_replay_buffer,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=buffer_scratch_dir,
                device=device,
            ),
        )
    else:
        replay_buffer = TensorDictReplayBuffer(
            pin_memory=False,
            prefetch=make_replay_buffer,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=buffer_scratch_dir,
                device=device,
            ),
        )
    return replay_buffer


@hydra.main(version_base=None, config_path=".", config_name="online_config")
def main(cfg: "DictConfig"):  # noqa: F821

    device = (
        torch.device("cuda:0")
        if torch.cuda.is_available()
        and torch.cuda.device_count() > 0
        and cfg.device == "cuda:0"
        else torch.device("cpu")
    )

    exp_name = generate_exp_name("Online_IQL", cfg.exp_name)
    logger = get_logger(
        logger_type=cfg.logger,
        logger_name="iql_logging",
        experiment_name=exp_name,
        wandb_kwargs={"mode": cfg.mode},
    )

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    def env_factory(num_workers):
        """Creates an instance of the environment."""

        # 1.2 Create env vector
        vec_env = ParallelEnv(
            create_env_fn=EnvCreator(lambda: env_maker(env_name=cfg.env_name)),
            num_workers=num_workers,
        )

        return vec_env

    # Sanity check
    test_env = env_factory(num_workers=5)
    num_actions = test_env.action_spec.shape[-1]

    # Create Agent
    # Define Actor Network
    in_keys = ["observation"]
    action_spec = test_env.action_spec
    actor_net_kwargs = {
        "num_cells": [256, 256],
        "out_features": 2 * num_actions,
        "activation_class": nn.ReLU,
    }

    actor_net = MLP(**actor_net_kwargs)

    dist_class = TanhNormal
    dist_kwargs = {
        "min": action_spec.space.minimum[-1],
        "max": action_spec.space.maximum[-1],
        "tanh_loc": cfg.tanh_loc,
    }

    actor_extractor = NormalParamExtractor(
        scale_mapping=f"biased_softplus_{cfg.default_policy_scale}",
        scale_lb=cfg.scale_lb,
    )

    actor_net = nn.Sequential(actor_net, actor_extractor)
    in_keys_actor = in_keys
    actor_module = TensorDictModule(
        actor_net,
        in_keys=in_keys_actor,
        out_keys=[
            "loc",
            "scale",
        ],
    )
    actor = ProbabilisticActor(
        spec=action_spec,
        in_keys=["loc", "scale"],
        module=actor_module,
        distribution_class=dist_class,
        distribution_kwargs=dist_kwargs,
        default_interaction_mode="random",
        return_log_prob=False,
    )

    # Define Critic Network
    qvalue_net_kwargs = {
        "num_cells": [256, 256],
        "out_features": 1,
        "activation_class": nn.ReLU,
    }

    qvalue_net = MLP(
        **qvalue_net_kwargs,
    )

    qvalue = ValueOperator(
        in_keys=["action"] + in_keys,
        module=qvalue_net,
    )

    # Define Value Network
    value_net_kwargs = {
        "num_cells": [256, 256],
        "out_features": 1,
        "activation_class": nn.ReLU,
    }
    value_net = MLP(**value_net_kwargs)
    value = ValueOperator(
        in_keys=in_keys,
        module=value_net,
    )

    model = nn.ModuleList([actor, qvalue, value]).to(device)

    # init nets
    with torch.no_grad():
        td = test_env.reset()
        td = td.to(device)
        actor(td)
        qvalue(td)
        value(td)

    del td
    test_env.close()
    test_env.eval()

    # Create IQL loss
    loss_module = IQLLoss(
        actor_network=model[0],
        qvalue_network=model[1],
        value_network=model[2],
        num_qvalue_nets=2,
        gamma=cfg.gamma,
        temperature=cfg.temperature,
        expectile=cfg.expectile,
        loss_function="smooth_l1",
    )

    # Define Target Network Updater
    target_net_updater = SoftUpdate(loss_module, cfg.target_update_polyak)

    # Make Off-Policy Collector
    collector = SyncDataCollector(
        env_factory,
        create_env_kwargs={"num_workers": cfg.env_per_collector},
        policy=model[0],
        frames_per_batch=cfg.frames_per_batch,
        max_frames_per_traj=cfg.max_frames_per_traj,
        total_frames=cfg.total_frames,
        device=cfg.device,
    )
    collector.set_seed(cfg.seed)

    # Make Replay Buffer
    replay_buffer = make_replay_buffer(buffer_size=cfg.buffer_size, device=device)

    # Optimizers
    params = list(loss_module.parameters())
    optimizer = optim.Adam(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

    rewards = []
    rewards_eval = []

    # Main loop
    target_net_updater.init_()

    collected_frames = 0

    pbar = tqdm.tqdm(total=cfg.total_frames)
    r0 = None
    loss = None

    for i, tensordict in enumerate(collector):

        # update weights of the inference policy
        collector.update_policy_weights_()

        if r0 is None:
            r0 = tensordict["reward"].sum(-1).mean().item()
        pbar.update(tensordict.numel())

        if "mask" in tensordict.keys():
            # if multi-step, a mask is present to help filter padded values
            current_frames = tensordict["mask"].sum()
            tensordict = tensordict[tensordict.get("mask").squeeze(-1)]
        else:
            tensordict = tensordict.view(-1)
            current_frames = tensordict.numel()
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

        (
            actor_losses,
            q_losses,
            value_losses,
        ) = ([], [], [])
        # optimization steps
        for _ in range(cfg.frames_per_batch * int(cfg.utd_ratio)):
            # sample from replay buffer
            sampled_tensordict = replay_buffer.sample(cfg.batch_size).clone()

            loss_td = loss_module(sampled_tensordict)

            actor_loss = loss_td["loss_actor"]
            q_loss = loss_td["loss_qvalue"]
            value_loss = loss_td["loss_value"]

            loss = actor_loss + q_loss + value_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            q_losses.append(q_loss.item())
            actor_losses.append(actor_loss.item())
            value_losses.append(value_loss.item())

            # update qnet_target params
            target_net_updater.step()

            # update priority
            if cfg.prb:
                replay_buffer.update_priority(sampled_tensordict)

        rewards.append((i, tensordict["reward"].sum().item() / cfg.env_per_collector))
        train_log = {
            "train_reward": rewards[-1][1],
            "collected_frames": collected_frames,
        }
        if q_loss is not None:
            train_log.update(
                {
                    "actor_loss": np.mean(actor_losses),
                    "q_loss": np.mean(q_losses),
                    "value_loss": np.mean(value_losses),
                }
            )
        for key, value in train_log.items():
            logger.log_scalar(key, value, step=collected_frames)

        with set_exploration_mode("mean"), torch.no_grad():
            eval_rollout = test_env.rollout(
                max_steps=cfg.max_frames_per_traj,
                policy=model[0],
                auto_cast_to_device=True,
            ).clone()
            eval_reward = eval_rollout["reward"].sum(-2).mean().item()
            rewards_eval.append((i, eval_reward))
            eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
            logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames)
        if len(rewards_eval):
            pbar.set_description(
                f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
            )

    collector.shutdown()


if __name__ == "__main__":
    main()
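
To run the new example outside CI, the same Hydra overrides exercised in run_test.sh above can be passed on the command line. A minimal smoke-test sketch (values copied from the CI invocation and intentionally tiny; for real training, pick a Gym task via env_name and raise total_frames, buffer_size and batch_size in the example's online_config):

python examples/iql/iql_online.py \
  total_frames=48 \
  batch_size=10 \
  frames_per_batch=16 \
  num_workers=2 \
  env_per_collector=1 \
  mode=offline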
