Skip to content

Commit 64e4669

Browse files
committed
allow updating rollout policy
1 parent df59f26 commit 64e4669

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

pomdp_py/algorithms/po_rollout.pxd

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,4 +16,6 @@ cdef class PORollout(Planner):
1616
cpdef _search(self)
1717
cpdef _rollout(self, State state, int depth)
1818
cpdef update(self, Agent agent, Action real_action, Observation real_observation,
19-
state_transform_func=*)
19+
state_transform_func=*)
20+
21+
cpdef set_rollout_policy(self, RolloutPolicy rollout_policy)

pomdp_py/algorithms/po_rollout.pyx

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ cdef class PORollout(Planner):
6363
cdef float best_reward, reward_avg, total_discounted_reward
6464
cdef set legal_actions
6565
cdef list rewards
66-
66+
6767
best_action, best_reward = None, float("-inf")
6868
legal_actions = self._agent.valid_actions(history=self._agent.history)
6969
for action in legal_actions:
@@ -77,7 +77,7 @@ cdef class PORollout(Planner):
7777
best_action = action
7878
best_reward = reward_avg
7979
return best_action, best_reward
80-
80+
8181
cpdef _rollout(self, State state, int depth):
8282
# Rollout without a tree.
8383
cdef Action action
@@ -88,7 +88,7 @@ cdef class PORollout(Planner):
8888
cdef float reward
8989
cdef int nsteps
9090
cdef tuple history = self._agent.history
91-
91+
9292
while depth <= self._max_depth:
9393
action = self._rollout_policy.rollout(state, history=history)
9494
next_state, observation, reward, nsteps = sample_generative_model(self._agent, state, action)
@@ -118,7 +118,7 @@ cdef class PORollout(Planner):
118118
agent.set_belief(particle_reinvigoration(new_belief,
119119
len(agent.init_belief.particles),
120120
state_transform_func=state_transform_func))
121-
121+
122122
@property
123123
def update_agent_belief(self):
124124
"""True if planner's update function also updates agent's
@@ -129,3 +129,10 @@ cdef class PORollout(Planner):
129129
"""clear_agent(self)"""
130130
self._agent = None # forget about current agent so that can plan for another agent.
131131
self._last_best_reward = float('-inf')
132+
133+
cpdef set_rollout_policy(self, RolloutPolicy rollout_policy):
134+
"""
135+
set_rollout_policy(self, RolloutPolicy rollout_policy)
136+
Updates the rollout policy to the given one
137+
"""
138+
self._rollout_policy = rollout_policy

pomdp_py/algorithms/po_uct.pxd

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ cdef class POUCT(Planner):
4040
cpdef _rollout(self, State state, tuple history, VNode root, int depth)
4141
cpdef Action _ucb(self, VNode root)
4242
cpdef tuple _sample_generative_model(self, State state, Action action)
43+
cpdef set_rollout_policy(self, RolloutPolicy rollout_policy)
4344

4445
cdef class RolloutPolicy(PolicyModel):
4546
cpdef Action rollout(self, State state, tuple history)

pomdp_py/algorithms/po_uct.pyx

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,13 @@ cdef class POUCT(Planner):
260260
self._agent = None # forget about current agent so that can plan for another agent.
261261
self._last_num_sims = -1
262262

263+
cpdef set_rollout_policy(self, RolloutPolicy rollout_policy):
264+
"""
265+
set_rollout_policy(self, RolloutPolicy rollout_policy)
266+
Updates the rollout policy to the given one
267+
"""
268+
self._rollout_policy = rollout_policy
269+
263270
cpdef _expand_vnode(self, VNode vnode, tuple history, State state=None):
264271
cdef Action action
265272
cdef tuple preference

0 commit comments

Comments
 (0)