Commit 31596b8

Parallelize value functions and J reward (#159)
* Optimize compute_J, GAE and Monte Carlo advantage
* Update episode and value function tests
* Fix dataset parse
* Fix GAE sign
* Fix A2C test
1 parent 63579d7 commit 31596b8

9 files changed, +286 -41 lines

mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py

Lines changed: 2 additions & 2 deletions
@@ -57,10 +57,10 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
         )

     def fit(self, dataset):
-        state, action, reward, next_state, absorbing, _ = dataset.parse(to='torch')
+        state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')

         v, adv = compute_advantage_montecarlo(self._V, state, next_state,
-                                              reward, absorbing,
+                                              reward, absorbing, last,
                                               self.mdp_info.gamma)
         self._V.fit(state, v, **self._critic_fit_params)

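The only behavioural change here is that fit() now keeps the `last` flags returned by dataset.parse and forwards them to compute_advantage_montecarlo, so advantages can be computed per episode. A minimal sketch (toy tensors, not a MushroomRL dataset) of what those flags encode:

import torch

# Two episodes flattened into one batch of 7 steps; `last` is True on the
# final step of each episode, which is what the advantage code uses to
# find episode boundaries.
reward = torch.tensor([1., 1., 1., 1., 2., 2., 2.])
last = torch.tensor([False, False, False, True, False, False, True])

print(torch.nonzero(last).squeeze(-1))  # tensor([3, 6]) -> terminal step indices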

mushroom_rl/core/array_backend.py

Lines changed: 30 additions & 2 deletions
@@ -171,7 +171,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         raise NotImplementedError
-
+
+    @staticmethod
+    def nonzero(array):
+        raise NotImplementedError
+
+    @staticmethod
+    def repeat(array, repeats):
+        raise NotImplementedError

 class NumpyBackend(ArrayBackend):
     @staticmethod
@@ -188,7 +195,12 @@ def to_numpy(array):

     @staticmethod
     def to_torch(array):
-        return None if array is None else torch.from_numpy(array).to(TorchUtils.get_device())
+        if array is None:
+            return None
+        else:
+            if array.dtype == np.float64:
+                array = array.astype(np.float32)
+            return torch.from_numpy(array).to(TorchUtils.get_device())

     @staticmethod
     def convert_to_backend(cls, array):
@@ -303,6 +315,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return np.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return np.flatnonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return np.repeat(array, repeats)


 class TorchBackend(ArrayBackend):
@@ -443,6 +463,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return torch.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return torch.nonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return torch.repeat_interleave(array, repeats)

 class ListBackend(ArrayBackend):
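A quick sketch of what the two new backend primitives return, using the underlying NumPy and torch calls directly on toy inputs. Note that torch.nonzero yields a column of indices, which is why the caller in episodes.py squeezes the result:

import numpy as np
import torch

mask = np.array([False, True, False, True])
print(np.flatnonzero(mask))                           # [1 3]
print(np.repeat(np.array([0, 1]), [2, 3]))            # [0 0 1 1 1]

mask_t = torch.tensor([False, True, False, True])
print(torch.nonzero(mask_t))                          # tensor([[1], [3]])
print(torch.repeat_interleave(torch.tensor([0, 1]),
                              torch.tensor([2, 3])))  # tensor([0, 0, 1, 1, 1])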

mushroom_rl/core/dataset.py

Lines changed: 14 additions & 16 deletions
@@ -11,6 +11,7 @@

 from ._impl import *

+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

 class DatasetInfo(Serializable):
     def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype,
@@ -473,22 +474,19 @@ def compute_J(self, gamma=1.):
            The cumulative discounted reward of each episode in the dataset.

        """
-        js = list()
-
-        j = 0.
-        episode_steps = 0
-        for i in range(len(self)):
-            j += gamma ** episode_steps * self.reward[i]
-            episode_steps += 1
-            if self.last[i] or i == len(self) - 1:
-                js.append(j)
-                j = 0.
-                episode_steps = 0
-
-        if len(js) == 0:
-            js = [0.]
-
-        return self._array_backend.from_list(js)
+        r_ep = split_episodes(self.last, self.reward)
+
+        if len(r_ep.shape) == 1:
+            r_ep = r_ep.unsqueeze(0)
+        if hasattr(r_ep, 'device'):
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
+        else:
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype)
+
+        for k in range(r_ep.shape[1]):
+            js += gamma ** k * r_ep[..., k]
+
+        return js

     def compute_metrics(self, gamma=1.):
         """

mushroom_rl/rl_utils/value_functions.py

Lines changed: 23 additions & 17 deletions
@@ -1,7 +1,7 @@
 import torch
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

-
-def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
+def compute_advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
     """
     Function to estimate the advantage and new value function target
     over a dataset. The value function is estimated using rollouts
@@ -24,18 +24,21 @@ def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
     """
     with torch.no_grad():
         r = r.squeeze()
-        q = torch.zeros(len(r))
         v = V(s).squeeze()

-        q_next = V(ss[-1]).squeeze().item()
-        for rev_k in range(len(r)):
-            k = len(r) - rev_k - 1
-            q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
-            q[k] = q_next
+        r_ep, absorbing_ep, ss_ep = split_episodes(last, r, absorbing, ss)
+        q_ep = torch.zeros_like(r_ep, dtype=torch.float32)
+        q_next_ep = V(ss_ep[..., -1, :]).squeeze()
+
+        for rev_k in range(r_ep.shape[-1]):
+            k = r_ep.shape[-1] - rev_k - 1
+            q_next_ep = r_ep[..., k] + gamma * q_next_ep * (1 - absorbing_ep[..., k].int())
+            q_ep[..., k] = q_next_ep

+        q = unsplit_episodes(last, q_ep)
         adv = q - v
-        return q[:, None], adv[:, None]

+        return q[:, None], adv[:, None]

 def compute_advantage(V, s, ss, r, absorbing, gamma):
     """
@@ -97,13 +100,16 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam):
     with torch.no_grad():
         v = V(s)
         v_next = V(ss)
-        gen_adv = torch.empty_like(v)
-        for rev_k in range(len(v)):
-            k = len(v) - rev_k - 1
-            if last[k] or rev_k == 0:
-                gen_adv[k] = r[k] - v[k]
-                if not absorbing[k]:
-                    gen_adv[k] += gamma * v_next[k]
+
+        v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r, absorbing)
+        gen_adv_ep = torch.zeros_like(v_ep)
+        for rev_k in range(v_ep.shape[-1]):
+            k = v_ep.shape[-1] - rev_k - 1
+            if rev_k == 0:
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k]
             else:
-                gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1]
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1]
+
+        gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1)
+
     return gen_adv + v, gen_adv
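For reference, a self-contained sketch of the GAE recursion that the parallelized loop above applies within each padded episode, on made-up values. This is only an illustration of the formula, not the library code:

import torch

# delta_k = r_k + gamma * (1 - absorbing_k) * V(s_{k+1}) - V(s_k)
# A_k     = delta_k + gamma * lam * A_{k+1}, swept backwards in time,
# with the final step of the episode taking A_k = delta_k.
gamma, lam = 0.99, 0.95
r = torch.tensor([1., 1., 1.])           # rewards of one episode
v = torch.tensor([0.5, 0.4, 0.3])        # V(s_k)
v_next = torch.tensor([0.4, 0.3, 0.0])   # V(s_{k+1})
absorbing = torch.tensor([0., 0., 1.])   # last step ends in an absorbing state

adv = torch.zeros_like(r)
for k in reversed(range(len(r))):
    delta = r[k] + gamma * (1 - absorbing[k]) * v_next[k] - v[k]
    adv[k] = delta if k == len(r) - 1 else delta + gamma * lam * adv[k + 1]

print(adv)  # generalized advantage estimates for the three steps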

mushroom_rl/utils/episodes.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+from mushroom_rl.core.array_backend import ArrayBackend
+
+def split_episodes(last, *arrays):
+    """
+    Split a array from shape (n_steps) to (n_episodes, max_episode_steps).
+    """
+    backend = ArrayBackend.get_array_backend_from(last)
+
+    if last.sum().item() <= 1:
+        return arrays if len(arrays) > 1 else arrays[0]
+
+    row_idx, colum_idx, n_episodes, max_episode_steps = _get_episode_idx(last, backend)
+    episodes_arrays = []
+
+    for array in arrays:
+        array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if hasattr(array, 'device') else None)
+
+        array_ep[row_idx, colum_idx] = array
+        episodes_arrays.append(array_ep)
+
+    return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]
+
+def unsplit_episodes(last, *episodes_arrays):
+    """
+    Unsplit a array from shape (n_episodes, max_episode_steps) to (n_steps).
+    """
+
+    if last.sum().item() <= 1:
+        return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]
+
+    row_idx, colum_idx, _, _ = _get_episode_idx(last)
+    arrays = []
+
+    for episode_array in episodes_arrays:
+        array = episode_array[row_idx, colum_idx]
+        arrays.append(array)
+
+    return arrays if len(arrays) > 1 else arrays[0]
+
+def _get_episode_idx(last, backend=None):
+    if backend is None:
+        backend = ArrayBackend.get_array_backend_from(last)
+
+    n_episodes = last.sum()
+    last_idx = backend.nonzero(last).squeeze()
+    first_steps = backend.from_list([last_idx[0] + 1])
+    if hasattr(last, 'device'):
+        first_steps = first_steps.to(last.device)
+    episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]])
+    max_episode_steps = episode_steps.max()
+
+    start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1])
+    range_n_episodes = backend.arange(0, n_episodes, dtype=int)
+    range_len = backend.arange(0, last.shape[0], dtype=int)
+    if hasattr(last, 'device'):
+        range_n_episodes = range_n_episodes.to(last.device)
+        range_len = range_len.to(last.device)
+    row_idx = backend.repeat(range_n_episodes, episode_steps)
+    colum_idx = range_len - start_idx[row_idx]
+
+    return row_idx, colum_idx, n_episodes, max_episode_steps
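A small usage sketch of the round trip, based on the code above; the outputs in the comments are the expected results for these toy inputs (the second, shorter episode is zero-padded in the split view):

import torch
from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

reward = torch.tensor([1., 2., 3., 4., 5.])
last = torch.tensor([False, False, True, False, True])  # episodes of length 3 and 2

r_ep = split_episodes(last, reward)   # expected: [[1., 2., 3.], [4., 5., 0.]]
flat = unsplit_episodes(last, r_ep)   # expected: [1., 2., 3., 4., 5.]
print(r_ep)
print(flat)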

tests/algorithms/test_a2c.py

Lines changed: 3 additions & 1 deletion
@@ -75,7 +75,7 @@ def test_a2c():
     agent = learn_a2c()

     w = agent.policy.get_weights()
-    w_test = np.array([0.9382279 , -1.8847059 , -0.13790752, -0.00786441])
+    w_test = np.array([ 0.9389272 ,-1.8838323 ,-0.13710725,-0.00668973])

     assert np.allclose(w, w_test)

@@ -95,3 +95,5 @@ def test_a2c_save(tmpdir):
         print(save_attr, load_attr)

         tu.assert_eq(save_attr, load_attr)
+
+test_a2c()

tests/core/test_dataset.py

Lines changed: 1 addition & 3 deletions
@@ -128,6 +128,4 @@ def test_dataset_loading(tmpdir):

     assert len(dataset.info) == len(new_dataset.info)
     for key in dataset.info:
-        assert np.array_equal(dataset.info[key], new_dataset.info[key])
-
-
+        assert np.array_equal(dataset.info[key], new_dataset.info[key])
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+import torch
+from mushroom_rl.policy import DeterministicPolicy
+from mushroom_rl.environments.segway import Segway
+from mushroom_rl.core import Core, Agent
+from mushroom_rl.approximators import Regressor
+from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator
+from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo
+
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes
+
+def test_compute_advantage_montecarlo():
+    def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
+        with torch.no_grad():
+            r = r.squeeze()
+            q = torch.zeros(len(r))
+            v = V(s).squeeze()
+
+            for rev_k in range(len(r)):
+                k = len(r) - rev_k - 1
+                if last[k] or rev_k == 0:
+                    q_next = V(ss[k]).squeeze().item()
+                q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
+                q[k] = q_next
+
+            adv = q - v
+            return q[:, None], adv[:, None]
+
+    torch.manual_seed(42)
+    _value_functions_tester(compute_advantage_montecarlo, advantage_montecarlo, 0.99)
+
+def test_compute_gae():
+    def gae(V, s, ss, r, absorbing, last, gamma, lam):
+        with torch.no_grad():
+            v = V(s)
+            v_next = V(ss)
+            gen_adv = torch.empty_like(v)
+            for rev_k in range(len(v)):
+                k = len(v) - rev_k - 1
+                if last[k] or rev_k == 0:
+                    gen_adv[k] = r[k] - v[k]
+                    if not absorbing[k]:
+                        gen_adv[k] += gamma * v_next[k]
+                else:
+                    gen_adv[k] = r[k] - v[k] + gamma * v_next[k] + gamma * lam * gen_adv[k + 1]
+            return gen_adv + v, gen_adv
+
+    torch.manual_seed(42)
+    _value_functions_tester(compute_gae, gae, 0.99, 0.95)
+
+def _value_functions_tester(test_fun, correct_fun, *args):
+    mdp = Segway()
+    V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape, output_shape=(1,), network=Net, loss=torch.nn.MSELoss(), optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}})
+
+    state, action, reward, next_state, absorbing, last = _get_episodes(mdp, 10)
+
+    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
+    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)
+
+    assert torch.allclose(v, correct_v)
+    assert torch.allclose(adv, correct_adv)
+
+    V.fit(state, correct_v)
+
+    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
+    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)
+
+    assert torch.allclose(v, correct_v)
+    assert torch.allclose(adv, correct_adv)
+
+def _get_episodes(mdp, n_episodes=100):
+    mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0)
+
+    approximator = Regressor(LinearApproximator,
+                             input_shape=mdp.info.observation_space.shape,
+                             output_shape=mdp.info.action_space.shape,
+                             weights=mu)
+
+    policy = DeterministicPolicy(approximator)
+
+    agent = Agent(mdp.info, policy)
+    core = Core(agent, mdp)
+    dataset = core.evaluate(n_episodes=n_episodes)
+
+    return dataset.parse(to='torch')
+
+class Net(torch.nn.Module):
+    def __init__(self, input_shape, output_shape, **kwargs):
+        super().__init__()
+        self._q = torch.nn.Linear(input_shape[0], output_shape[0])
+
+    def forward(self, x):
+        return self._q(x.float())
