Skip to content

Commit b6173bc

Browse files
committed
fix bugs
1 parent 556b51b commit b6173bc

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

rl_plotter/plotter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,14 @@ def main():
106106
# OpenAI spinup's progress
107107
if args.filename == 'progress.txt' or args.filename == 'progress.csv':
108108
args.xkey = 'TotalEnvInteracts'
109-
args.ykey = ['AverageTestEpRet']
109+
if len(args.ykey) == 1:
110+
args.ykey = ['AverageTestEpRet']
110111

111112
# rl-plotter's evaluator
112113
if args.filename == 'evaluator.csv':
113114
args.xkey = 'total_steps'
114-
args.ykey = ['mean_score']
115+
if len(args.ykey) == 1:
116+
args.ykey = ['mean_score']
115117

116118
if args.save is False:
117119
args.show = True

rl_plotter/plotter_spinup.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ def main():
7171
# OpenAI spinup's progress
7272
if args.filename == 'progress.txt' or args.filename == 'progress.csv':
7373
args.xkey = 'TotalEnvInteracts'
74-
args.ykey = ['AverageTestEpRet']
74+
if len(args.ykey) == 1:
75+
args.ykey = ['AverageTestEpRet']
7576

7677
# rl-plotter's evaluator
7778
if args.filename == 'evaluator.csv':
7879
args.xkey = 'total_steps'
79-
args.ykey = ['mean_score']
80+
if len(args.ykey) == 1:
81+
args.ykey = ['mean_score']
8082

8183
if args.save is False:
8284
args.show = True
@@ -86,12 +88,14 @@ def main():
8688
for result in allresults:
8789
result['data'].insert(len(result['data'].columns),'Condition1', pu.default_split_fn(result))
8890
datas.append(result['data'])
89-
pu.plot_data(data=datas, xaxis=args.xkey, value=args.ykey[0], smooth=args.smooth,
90-
legend_outside=args.legend_outside,
91-
legend_loc=args.legend_loc,
92-
legend_borderpad=args.borderpad,
93-
legend_labelspacing=args.labelspacing,
94-
font_scale=args.font_scale)
91+
for value in args.ykey:
92+
plt.figure()
93+
pu.plot_data(data=datas, xaxis=args.xkey, value=value, smooth=args.smooth,
94+
legend_outside=args.legend_outside,
95+
legend_loc=args.legend_loc,
96+
legend_borderpad=args.borderpad,
97+
legend_labelspacing=args.labelspacing,
98+
font_scale=args.font_scale)
9599
plt.title(args.title)
96100
plt.xlabel(args.xlabel)
97101
plt.ylabel(args.ylabel)

0 commit comments

Comments
 (0)