Skip to content

Commit 5203ce2

Browse files
Added tensorboard support.
1 parent 12c3934 commit 5203ce2

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

README.md

+24
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,27 @@ python enjoy.py --load-dir trained_models/ppo --env-name "Reacher-v2"
168168
![QbertNoFrameskip-v4](imgs/acktr_qbert.png)
169169

170170
![beamriderNoFrameskip-v4](imgs/acktr_beamrider.png)
171+
172+
## Visualization with tensorboard.
173+
174+
### Requirements
175+
176+
* [Tensorboard](https://github.com/tensorflow/tensorboard)
177+
* [tensorboardX](https://github.com/lanpa/tensorboardX)
178+
179+
### Installation of requirements
180+
181+
```bash
182+
pip install tensorboard
183+
pip install tensorboardX
184+
```
185+
186+
### Using tensorboard to visualize training
187+
188+
```bash
189+
python main.py --env-name "PongNoFrameskip-v4" --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 1 --num-processes 8 --num-steps 128 --num-mini-batch 4 --vis-interval 1 --log-interval 1 --tensorboard-logdir "/tmp/tfboard"
190+
tensorboard --logdir "/tmp/tfboard"
191+
```
192+
193+
In a browser open [localhost:6006](http://localhost:6006). Note that a new folder is created every time training is
194+
started with the current timestamp.

arguments.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def get_args():
6363
help='enable visdom visualization')
6464
parser.add_argument('--port', type=int, default=8097,
6565
help='port to run the server on (default: 8097)')
66+
parser.add_argument('--tensorboard-logdir', default=None,
67+
help='logs to tensorboard in the specified directory')
6668
args = parser.parse_args()
6769

6870
args.cuda = not args.no_cuda and torch.cuda.is_available()

main.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ def main():
5858
viz = Visdom(port=args.port)
5959
win = None
6060

61+
tensorboard_writer = None
62+
if args.tensorboard_logdir is not None:
63+
from tensorboardX import SummaryWriter
64+
import time, os, datetime
65+
ts_str = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')
66+
tensorboard_writer = SummaryWriter(log_dir=os.path.join(args.tensorboard_logdir, ts_str))
67+
6168
envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
6269
args.gamma, args.log_dir, args.add_timestep, device, False)
6370

@@ -150,9 +157,19 @@ def main():
150157
np.mean(episode_rewards),
151158
np.median(episode_rewards),
152159
np.min(episode_rewards),
153-
np.max(episode_rewards), dist_entropy,
160+
np.max(episode_rewards),
161+
dist_entropy,
154162
value_loss, action_loss))
155163

164+
if tensorboard_writer is not None:
165+
tensorboard_writer.add_scalar("mean reward", np.mean(episode_rewards), total_num_steps)
166+
tensorboard_writer.add_scalar("median reward", np.median(episode_rewards), total_num_steps)
167+
tensorboard_writer.add_scalar("min reward", np.min(episode_rewards), total_num_steps)
168+
tensorboard_writer.add_scalar("max reward", np.max(episode_rewards), total_num_steps)
169+
tensorboard_writer.add_scalar("dist entropy", dist_entropy, total_num_steps)
170+
tensorboard_writer.add_scalar("value loss", value_loss, total_num_steps)
171+
tensorboard_writer.add_scalar("action loss", action_loss, total_num_steps)
172+
156173
if (args.eval_interval is not None
157174
and len(episode_rewards) > 1
158175
and j % args.eval_interval == 0):

0 commit comments

Comments
 (0)