Commit ec6e043

added float_precision argument to to_pomdp_file (#29)
1 parent b61ffbb commit ec6e043

File tree

1 file changed (+8, -6 lines)


pomdp_py/utils/interfaces/conversion.py

Lines changed: 8 additions & 6 deletions
@@ -9,7 +9,7 @@
 import xml.etree.ElementTree as ET
 
 def to_pomdp_file(agent, output_path=None,
-                  discount_factor=0.95):
+                  discount_factor=0.95, float_precision=9):
     """
     Pass in an Agent, and use its components to generate
     a .pomdp file to `output_path`.
@@ -30,6 +30,8 @@ def to_pomdp_file(agent, output_path=None,
         output_path (str): The path of the output file to write in. Optional.
             Default None.
         discount_factor (float): The discount factor
+        float_precision (int): Number of decimals for float to str conversion.
+            Default 6.
     Returns:
         (list, list, list): The list of states, actions, observations that
             are ordered in the same way as they are in the .pomdp file.
@@ -42,7 +44,7 @@ def to_pomdp_file(agent, output_path=None,
     except NotImplementedError:
         raise ValueError("S, A, O must be enumerable for a given agent to convert to .pomdp format")
 
-    content = "discount: %f\n" % discount_factor
+    content = f"discount: %.{float_precision}f\n" % discount_factor
     content += "values: reward\n"  # We only consider reward, not cost.
 
     list_of_states = " ".join(str(s) for s in all_states)
@@ -62,7 +64,7 @@ def to_pomdp_file(agent, output_path=None,
 
     # Starting belief state - they need to be normalized
     total_belief = sum(agent.belief[s] for s in all_states)
-    content += "start: %s\n" % (" ".join(["%f" % (agent.belief[s]/total_belief)
+    content += "start: %s\n" % (" ".join([f"%.{float_precision}f" % (agent.belief[s]/total_belief)
                                           for s in all_states]))
 
     # State transition probabilities - they need to be normalized
@@ -75,7 +77,7 @@ def to_pomdp_file(agent, output_path=None,
             total_prob = sum(probs)
             for i, s_next in enumerate(all_states):
                 prob_norm = probs[i] / total_prob
-                content += 'T : %s : %s : %s %f\n' % (a, s, s_next, prob_norm)
+                content += f'T : %s : %s : %s %.{float_precision}f\n' % (a, s, s_next, prob_norm)
 
     # Observation probabilities - they need to be normalized
     for s_next in all_states:
@@ -90,15 +92,15 @@ def to_pomdp_file(agent, output_path=None,
                 .format(s_next, a)
             for i, o in enumerate(all_observations):
                 prob_norm = probs[i] / total_prob
-                content += 'O : %s : %s : %s %f\n' % (a, s_next, o, prob_norm)
+                content += f'O : %s : %s : %s %.{float_precision}f\n' % (a, s_next, o, prob_norm)
 
     # Immediate rewards
     for s in all_states:
         for a in all_actions:
             for s_next in all_states:
                 # We will take the argmax reward, which works for deterministic rewards.
                 r = agent.reward_model.sample(s, a, s_next)
-                content += 'R : %s : %s : %s : * %f\n' % (a, s, s_next, r)
+                content += f'R : %s : %s : %s : * %.{float_precision}f\n' % (a, s, s_next, r)
 
     if output_path is not None:
         with open(output_path, "w") as f:
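
For context, the change replaces the hard-coded "%f" conversions (which round to 6 decimal places) with a "%.<float_precision>f" format string built via an f-string. Below is a minimal sketch of that formatting pattern and of a call to the updated function; it is not part of the commit. The `agent` object and the "tiger.pomdp" path are hypothetical placeholders, while the import path, signature, and return values follow the diff above.

# The formatting pattern introduced by this commit: an f-string builds the
# "%.<precision>f" format string, which is then applied with the % operator.
precision = 3
print(f"discount: %.{precision}f" % 0.95)   # prints "discount: 0.950"

# Sketch of calling the updated function. Assumes `agent` is an already
# constructed pomdp_py Agent whose states, actions, and observations are
# enumerable; "tiger.pomdp" is an arbitrary example output path.
from pomdp_py.utils.interfaces.conversion import to_pomdp_file

states, actions, observations = to_pomdp_file(agent,
                                              output_path="tiger.pomdp",
                                              discount_factor=0.95,
                                              float_precision=12)
# Every probability and reward in tiger.pomdp is now written with "%.12f"
# instead of the previous bare "%f", which rounds to 6 decimal places.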
