@@ -68,8 +68,6 @@ def __init__(self, epsilon=1e-12):
        self.epsilon = epsilon

    def probability(self, next_state, state, action):
-        """According to problem spec, the world resets once
-        action is open-left/open-right. Otherwise, stays the same"""
        if self.sample(state, action) == next_state:
            return 1.0 - self.epsilon
        else:
@@ -87,8 +85,6 @@ def __init__(self, epsilon=1e-12):
        self.epsilon = epsilon

    def probability(self, observation, next_state, action):
-        """According to problem spec, the world resets once
-        action is open-left/open-right. Otherwise, stays the same"""
        if self.sample(next_state, action) == observation:
            return 1.0 - self.epsilon
        else:
@@ -122,3 +118,88 @@ def get_all_actions(self, state=None, history=None):

    def rollout(self, state, history=None):
        return random.sample(self.actions, 1)[0]
+
+
+# Tabular models
+class TabularTransitionModel(pomdp_py.TransitionModel):
+    """This tabular transition model is built given a dictionary that maps a tuple
+    (state, action, next_state) to a probability. This model assumes that the
+    given `weights` is complete, that is, it specifies the probability of all
+    state-action-next_state combinations.
+    """
+    def __init__(self, weights):
+        self.weights = weights
+        self._states = set()
+        for s, _, sp in weights:
+            self._states.add(s)
+            self._states.add(sp)
+
+    def probability(self, next_state, state, action):
+        if (state, action, next_state) in self.weights:
+            return self.weights[(state, action, next_state)]
+        raise ValueError("The transition probability for "
+                         f"{(state, action, next_state)} is not defined")
+
+    def sample(self, state, action):
+        next_states = list(self._states)
+        probs = [self.probability(next_state, state, action)
+                 for next_state in next_states]
+        return random.choices(next_states, weights=probs, k=1)[0]
+
+    def get_all_states(self):
+        return self._states
+
+
+class TabularObservationModel(pomdp_py.ObservationModel):
+    """This tabular observation model is built given a dictionary that maps a tuple
+    (next_state, action, observation) to a probability. This model assumes that the
+    given `weights` is complete.
+    """
+    def __init__(self, weights):
+        self.weights = weights
+        self._observations = set()
+        for _, _, z in weights:
+            self._observations.add(z)
+
+    def probability(self, observation, next_state, action):
+        """The observation is emitted from next_state."""
+        if (next_state, action, observation) in self.weights:
+            return self.weights[(next_state, action, observation)]
+        elif (next_state, observation) in self.weights:
+            return self.weights[(next_state, observation)]
+        raise ValueError("The observation probability for "
+                         f"{(next_state, action, observation)} or {(next_state, observation)} "
+                         "is not defined")
+
+    def sample(self, next_state, action):
+        observations = list(self._observations)
+        probs = [self.probability(observation, next_state, action)
+                 for observation in observations]
+        return random.choices(observations, weights=probs, k=1)[0]
+
+    def get_all_observations(self):
+        return self._observations
+
+
+class TabularRewardModel(pomdp_py.RewardModel):
+    """This tabular reward model is built given a dictionary that maps a state, a
+    tuple (state, action), or (state, action, next_state) to a reward. This
+    model assumes that the given `rewards` is complete.
+    """
+    def __init__(self, rewards):
+        self.rewards = rewards
+
+    def sample(self, state, action, *args):
+        if state in self.rewards:
+            return self.rewards[state]
+        elif (state, action) in self.rewards:
+            return self.rewards[(state, action)]
+        else:
+            if len(args) > 0:
+                next_state = args[0]
+                if (state, action, next_state) in self.rewards:
+                    return self.rewards[(state, action, next_state)]
+
+        raise ValueError("The reward is undefined for "
+                         f"state={state}, action={action}, "
+                         f"next_state={args}")
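A minimal usage sketch of the tabular models added above (not part of the commit). It assumes the three classes are importable from this module, referred to here by the hypothetical name `tiger_problem`, and uses plain strings as states, actions, and observations, since the models only require hashable dictionary keys. The 0.85 listening accuracy is illustrative only.

    # Hypothetical module name; adjust the import to wherever the classes live.
    from tiger_problem import (TabularTransitionModel,
                               TabularObservationModel,
                               TabularRewardModel)

    states = ["tiger-left", "tiger-right"]

    # Listening never changes the state; the weights dictionary covers
    # every (state, action, next_state) combination that will be queried.
    T = TabularTransitionModel({
        (s, "listen", sp): 1.0 if s == sp else 0.0
        for s in states for sp in states
    })

    # Listening reveals the tiger's position correctly 85% of the time
    # (illustrative numbers, keyed by (next_state, action, observation)).
    O = TabularObservationModel({
        (sp, "listen", z): 0.85 if sp == z else 0.15
        for sp in states for z in states
    })

    # Rewards keyed by (state, action); listening always costs 1.
    R = TabularRewardModel({(s, "listen"): -1 for s in states})

    print(T.probability("tiger-left", "tiger-left", "listen"))  # 1.0
    print(O.sample("tiger-right", "listen"))                     # noisy observation
    print(R.sample("tiger-left", "listen"))                      # -1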