Skip to content

Commit 9f664f2

Browse files
committed
adding classes for tabular models in templates
1 parent 9ad52fc commit 9f664f2

File tree

2 files changed

+86
-5
lines changed

2 files changed

+86
-5
lines changed

pomdp_py/utils/templates.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ def __init__(self, epsilon=1e-12):
6868
self.epsilon = epsilon
6969

7070
def probability(self, next_state, state, action):
71-
"""According to problem spec, the world resets once
72-
action is open-left/open-right. Otherwise, stays the same"""
7371
if self.sample(state, action) == next_state:
7472
return 1.0 - self.epsilon
7573
else:
@@ -87,8 +85,6 @@ def __init__(self, epsilon=1e-12):
8785
self.epsilon = epsilon
8886

8987
def probability(self, observation, next_state, action):
90-
"""According to problem spec, the world resets once
91-
action is open-left/open-right. Otherwise, stays the same"""
9288
if self.sample(next_state, action) == observation:
9389
return 1.0 - self.epsilon
9490
else:
@@ -122,3 +118,88 @@ def get_all_actions(self, state=None, history=None):
122118

123119
def rollout(self, state, history=None):
124120
return random.sample(self.actions, 1)[0]
121+
122+
123+
# Tabular models
124+
class TabularTransitionModel(pomdp_py.TransitionModel):
    """This tabular transition model is built given a dictionary that maps a tuple
    (state, action, next_state) to a probability. This model assumes that the
    given `weights` is complete, that is, it specifies the probability of all
    state-action-nextstate combinations
    """
    def __init__(self, weights):
        """
        Args:
            weights (dict): maps (state, action, next_state) tuples to
                probabilities; assumed complete over all combinations.
        """
        self.weights = weights
        # Collect the state space from both the source and destination
        # positions of every key, so get_all_states/sample see every state.
        self._states = set()
        for s, _, sp in weights:
            self._states.add(s)
            self._states.add(sp)

    def probability(self, next_state, state, action):
        """Return Pr(next_state | state, action).

        Raises:
            ValueError: if the (state, action, next_state) triple is not
                present in `weights`.
        """
        if (state, action, next_state) in self.weights:
            return self.weights[(state, action, next_state)]
        # Note the trailing space before the f-string: without it the
        # message reads "...probability for(s, a, s')...".
        raise ValueError("The transition probability for "
                         f"{(state, action, next_state)} is not defined")

    def sample(self, state, action):
        """Sample a next state according to Pr(. | state, action)."""
        next_states = list(self._states)
        probs = [self.probability(next_state, state, action)
                 for next_state in next_states]
        return random.choices(next_states, weights=probs, k=1)[0]

    def get_all_states(self):
        """Return the set of all states seen in `weights`."""
        return self._states
152+
153+
class TabularObservationModel(pomdp_py.ObservationModel):
    """This tabular observation model is built given a dictionary that maps a tuple
    (next_state, action, observation) to a probability. This model assumes that the
    given `weights` is complete. Keys may also be 2-tuples
    (next_state, observation) when the observation does not depend on the action.
    """
    def __init__(self, weights):
        """
        Args:
            weights (dict): maps (next_state, action, observation) or
                (next_state, observation) tuples to probabilities.
        """
        self.weights = weights
        self._observations = set()
        # `probability` accepts both 3-tuple and 2-tuple keys; in either
        # format the observation is the last element, so index with [-1]
        # instead of unpacking a fixed-length tuple (which would crash on
        # 2-tuple keys).
        for key in weights:
            self._observations.add(key[-1])

    def probability(self, observation, next_state, action):
        """Return Pr(observation | next_state, action).

        Raises:
            ValueError: if neither (next_state, action, observation) nor
                (next_state, observation) is present in `weights`.
        """
        if (next_state, action, observation) in self.weights:
            return self.weights[(next_state, action, observation)]
        elif (next_state, observation) in self.weights:
            return self.weights[(next_state, observation)]
        # Spaces between the concatenated literals keep the message readable.
        raise ValueError("The observation probability for "
                         f"{(next_state, action, observation)} or "
                         f"{(next_state, observation)} is not defined")

    def sample(self, next_state, action):
        """Sample an observation according to Pr(. | next_state, action)."""
        observations = list(self._observations)
        probs = [self.probability(observation, next_state, action)
                 for observation in observations]
        return random.choices(observations, weights=probs, k=1)[0]

    def get_all_observations(self):
        """Return the set of all observations seen in `weights`."""
        return self._observations
182+
183+
184+
class TabularRewardModel(pomdp_py.RewardModel):
    """This tabular reward model is built given a dictionary that maps a state or a
    tuple (state, action), or (state, action, next_state) to a reward. This
    model assumes that the given `rewards` is complete.
    """
    def __init__(self, rewards):
        """
        Args:
            rewards (dict): maps a state, (state, action), or
                (state, action, next_state) key to a numeric reward.
        """
        self.rewards = rewards

    def sample(self, state, action, *args):
        """Return the reward for the given transition.

        The lookup tries, in order: state-only keys, (state, action) keys,
        then (state, action, next_state) keys when a next_state is supplied
        as the first positional extra argument.

        Raises:
            ValueError: if no matching key exists in `rewards`.
        """
        if state in self.rewards:
            return self.rewards[state]
        elif (state, action) in self.rewards:
            return self.rewards[(state, action)]
        else:
            next_state = args[0] if args else None
            if next_state is not None\
               and (state, action, next_state) in self.rewards:
                return self.rewards[(state, action, next_state)]

        # Report the actual next_state (not the raw *args tuple), with
        # proper separators between the concatenated message pieces.
        raise ValueError("The reward is undefined for "
                         f"state={state}, action={action}, "
                         f"next_state={next_state}")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def build_extensions(pkg_name, major_submodules):
3434

3535
setup(name='pomdp-py',
3636
packages=find_packages(),
37-
version='1.3.0.1',
37+
version='1.3.1',
3838
description='Python POMDP Library.',
3939
long_description=long_description,
4040
long_description_content_type="text/x-rst",

0 commit comments

Comments
 (0)