This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 05405ef
Author: Jules Pondard
Parent: a90403a

Add an RL actor-critic-like algorithm for benchmarking

This algorithm uses a variant of experience replay and one policy per option to be predicted.
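For context, a minimal sketch of the "one policy per option" idea described in the message: each option gets its own small policy, and every sampled choice is appended to the input of the next policy. The sizes and names below are made up for illustration only; the actual implementation is in the diff below.

import torch
import torch.nn as nn
from torch.distributions import Categorical

init_input_sz = 7          # size of the problem-size encoding (hypothetical)
option_sizes = [3, 5, 2]   # number of choices per option (hypothetical)

# One linear policy head per option; option i also sees the i previous choices.
policies = nn.ModuleList(
    nn.Linear(init_input_sz + i, n_choices)
    for i, n_choices in enumerate(option_sizes)
)

x = torch.rand(init_input_sz)  # encoded problem sizes
choices = []
for policy in policies:
    probs = torch.softmax(policy(x), dim=-1)
    action = Categorical(probs).sample()
    choices.append(action.item())
    # Condition the next option on the choice just made.
    x = torch.cat([x, action.float().unsqueeze(0)])
print(choices)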

File tree: 1 file changed, 173 additions, 0 deletions

@@ -0,0 +1,173 @@
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
#import ipdb
from itertools import count
from collections import namedtuple
from torch.distributions import Categorical
import tensor_comprehensions as tc
from visdom import Visdom
from collections import deque
from heapq import heappush, heappop

import my_utils

NB_EPOCHS = 1000
BATCH_SZ = 16
buff = deque()
MAXI_BUFF_SZ = 50

(tc_code, tc_name, inp, init_input_sz) = my_utils.get_convolution_example(size_type="input", inp_sz_list=[8,2,28,28,8,1,1])

my_utils.computeCat(inp)
my_utils.set_tc(tc_code, tc_name)
NB_HYPERPARAMS, INIT_INPUT_SZ = my_utils.NB_HYPERPARAMS, my_utils.INIT_INPUT_SZ

viz = Visdom()
win0 = viz.line(X=np.arange(NB_EPOCHS), Y=np.random.rand(NB_EPOCHS))
win1 = viz.line(X=np.arange(NB_EPOCHS), Y=np.random.rand(NB_EPOCHS))

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

layer_sz = 32

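# Predictor: one small actor-critic head per option. It maps the problem-size
# encoding (plus the options already chosen) to a softmax over this option's
# choices and to a scalar value estimate used as a baseline.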
class Predictor(nn.Module):
    def __init__(self, nb_inputs, nb_actions):
        super(Predictor, self).__init__()
        self.affine1 = nn.Linear(nb_inputs, layer_sz)
        self.affine15 = nn.Linear(layer_sz, layer_sz)
        self.affine2 = nn.Linear(layer_sz, nb_actions)
        self.affine3 = nn.Linear(layer_sz, 1)

        self.W = nn.Linear(nb_inputs, nb_inputs)

    def forward(self, x):
        #ipdb.set_trace()
        #x = F.softmax(self.W(x), dim=-1) * x  # attention mechanism
        tmp1 = F.relu(self.affine1(x))
        #tmp1 = F.relu(self.affine15(tmp1))
        out_action = F.softmax(self.affine2(tmp1), dim=-1)
        out_value = self.affine3(tmp1)
        return out_action, out_value

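# FullNetwork: chains one Predictor per hyperparameter. Each sampled choice is
# appended to the input of the next Predictor, and select_action adds a small
# epsilon-greedy exploration step (uniform sampling with probability geps).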
class FullNetwork(nn.Module):
    def __init__(self, nb_hyperparams, init_input_sz):
        super(FullNetwork, self).__init__()
        self.nb_hyperparams = nb_hyperparams
        self.init_input_sz = init_input_sz
        self.nets = [Predictor(init_input_sz + i, int(my_utils.cat_sz[i])) for i in range(nb_hyperparams)]
        self.nets = nn.ModuleList(self.nets)

    def select_action(self, x, i, out_sz):
        geps = 0.1
        proba = np.random.rand()
        probs, state_value = self.nets[i](x)
        if proba <= geps:
            probs = torch.FloatTensor([1./out_sz]*out_sz)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action), state_value

    def forward(self, x):
        actions_prob = []
        values = []
        for i in range(self.nb_hyperparams):
            sym, action_prob, value = self.select_action(x, i, int(my_utils.cat_sz[i]))
            actions_prob.append(action_prob)
            values.append(value)
            x = torch.cat([x, torch.FloatTensor([sym])])
        return x[INIT_INPUT_SZ:], actions_prob, values

net = FullNetwork(NB_HYPERPARAMS, INIT_INPUT_SZ)
optimizer = optim.Adam(net.parameters(), lr=0.0001)
eps = np.finfo(np.float32).eps.item()

#print(my_utils.getAllDivs(inp))

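# finish_episode: actor-critic update on a replayed batch. The advantage is the
# final reward minus the value estimate; the policy loss is the usual
# -log_prob * advantage, and the value loss is a smooth L1 towards the reward.
# Per-episode losses are summed, then averaged over the batch.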
def finish_episode(actions_probs, values, final_rewards):
    policy_losses = [[] for i in range(BATCH_SZ)]
    value_losses = [[] for i in range(BATCH_SZ)]
    final_rewards = torch.tensor(list(final_rewards))
    #final_rewards = (final_rewards - final_rewards.mean()) / (final_rewards.std() + eps)
    for batch_id in range(BATCH_SZ):
        for (log_prob, value) in zip(actions_probs[batch_id], values[batch_id]):
            reward = final_rewards[batch_id] - value.item()
            policy_losses[batch_id].append(-log_prob * reward)
            value_losses[batch_id].append(F.smooth_l1_loss(value, torch.tensor([final_rewards[batch_id]])))
    optimizer.zero_grad()
    vloss = torch.stack([torch.stack(value_losses[i]).sum() for i in range(BATCH_SZ)]).mean()
    ploss = torch.stack([torch.stack(policy_losses[i]).sum() for i in range(BATCH_SZ)]).mean()
    loss = ploss + vloss
    loss.backward(retain_graph=True)
    optimizer.step()
    return vloss.item(), ploss.item()

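# Experience-replay variant: a bounded deque of (reward, log_probs, values)
# tuples. select_batch draws BATCH_SZ entries uniformly at random (with
# replacement); get_best_buff returns the best reward currently stored.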
def add_to_buffer(actions_probs, values, reward):
    global buff
    #if(len(buff) > 0):
    #    min_reward = np.min(np.array(buff)[:,2])
    #    if(reward < 10*min_reward):
    #        return
    if len(buff) == MAXI_BUFF_SZ:
        #heappop(buff)
        buff.popleft()
    #heappush(buff, (reward, actions_probs, values))
    buff.append((reward, actions_probs, values))

def select_batch():
    #random.sample()
    batch = [buff[np.random.randint(len(buff))] for i in range(BATCH_SZ)]
    #batch.append(buff[-1])
    batch = np.array(batch)
    return batch[:,1], batch[:,2], batch[:,0]

def get_best_buff():
    return np.max(np.array(buff)[:,0])

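# Main loop: sample a full set of options, evaluate the candidate's execution
# time with my_utils.evalTime, turn it into a reward via -log(time), store the
# episode in the buffer, and (after a warm-up of 20 iterations) train on a
# replayed batch. The running reward, best time and losses are plotted with
# Visdom every INTER_DISP iterations.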
INTER_DISP = 20

running_reward = -0.5
tab_rewards = []
tab_best = []
best = -12
v_losses = []
p_losses = []
best_options = np.zeros(NB_HYPERPARAMS).astype(int)
for i in range(NB_EPOCHS):
    rewards = []
    out_actions, out_probs, out_values = net(init_input_sz)
    #my_utils.print_opt(out_actions.numpy().astype(int))
    reward = my_utils.evalTime(out_actions.numpy().astype(int), prune=-1, curr_best=np.exp(-best))
    #reward = 100*reward
    #reward = -((reward)/1000)
    reward = -np.log(reward)
    add_to_buffer(out_probs, out_values, reward)
    best_in_buffer = get_best_buff()
    if i >= 20:
        actions_probs, values, rewards = select_batch()
        for j in range(1):
            vloss, ploss = finish_episode(actions_probs, values, rewards)
            v_losses.append(vloss)
            p_losses.append(ploss)
    if best < reward or i == 0:
        best = reward
        best_options = out_actions.numpy().astype(int)
        my_utils.print_opt(best_options)
    if i == 0:
        running_reward = reward
    running_reward = running_reward * 0.99 + reward * 0.01
    tab_rewards.append(-(running_reward))
    tab_best.append(-best)
    if i % INTER_DISP == 0:
        viz.line(X=np.column_stack((np.arange(i+1), np.arange(i+1))), Y=np.column_stack((np.array(tab_rewards), np.array(tab_best))), win=win0, opts=dict(legend=["Geometric run", "Best time"]))
        if len(v_losses) > 0:
            viz.line(X=np.column_stack((np.arange(len(v_losses)), np.arange(len(v_losses)))), Y=np.column_stack((np.array(v_losses), np.array(p_losses))), win=win1, opts=dict(legend=["Value loss", "Policy loss"]))
    print(-running_reward)
    print(-best)
    print("Best in buffer: " + str(-best_in_buffer))

print("Finally, best options are:")
my_utils.print_opt(best_options)
