(Vision support is currently being added.)
Currently, all dependencies are installed with pip. TODO: add a conda environment.
- Clone this repo.
- Install the dependencies with pip:

```
pip install -r requirements.txt
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -U numba
```
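After installing, it is worth confirming that JAX actually picked up the CUDA backend: with a mismatched CUDA install, JAX falls back to the CPU with only a warning. The following is a minimal sketch using `jax.default_backend()`; the helper names `jax_backend` and `numba_version` are hypothetical and not part of this repo.

```python
import importlib
from typing import Optional


def jax_backend() -> Optional[str]:
    """Return JAX's default backend ("gpu", "cpu", ...), or None if JAX is missing."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    return jax.default_backend()


def numba_version() -> Optional[str]:
    """Return the installed numba version, or None if numba is missing."""
    try:
        numba = importlib.import_module("numba")
    except ImportError:
        return None
    return numba.__version__


if __name__ == "__main__":
    backend = jax_backend()
    print("JAX backend:", backend)
    print("numba version:", numba_version())
    if backend != "gpu":
        print("Warning: JAX is not using the GPU; check the CUDA 12 install.")
```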
Edit the environment setup and hyperparameter settings in `server_run.py`. The current config is:
```python
config = {
    "env_name": "walker",
    "algo_name": "ppo",
    "task_name": "gap",
    "num_timesteps": 10_000_000,
    "num_evals": 1000,
    "eval_every": 10_000,
    "episode_length": 1000,
    "num_envs": 512,
    "batch_size": 512,
    "num_minibatches": 32,
    "num_updates_per_batch": 2,
    "unroll_length": 5,
}
```
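The numbers in this config are internally consistent: 10_000_000 timesteps over 1000 evaluations is one evaluation every 10_000 steps, which matches `eval_every`. The sketch below checks such relationships before a long run is launched. It rests on two assumptions that this README does not confirm: that `eval_every` is meant to equal `num_timesteps / num_evals`, and that `batch_size`, `num_minibatches` and `unroll_length` follow Brax-style PPO semantics (with `action_repeat = 1`). Under those semantics, `batch_size * num_minibatches` must be divisible by `num_envs`, and each training update consumes `batch_size * num_minibatches * unroll_length` environment steps. `check_config` is a hypothetical helper, not part of `server_run.py`.

```python
# Hypothetical pre-flight check for the training config; not part of server_run.py.
INT_KEYS = (
    "num_timesteps", "num_evals", "eval_every", "episode_length", "num_envs",
    "batch_size", "num_minibatches", "num_updates_per_batch", "unroll_length",
)


def check_config(config: dict) -> int:
    """Validate a PPO config dict and return env steps per training update.

    Assumes Brax-style PPO semantics with action_repeat = 1; adjust the
    checks if server_run.py interprets these keys differently.
    """
    for key in INT_KEYS:
        value = config[key]
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"{key} must be a positive int, got {value!r}")

    # Assumed intent: one evaluation every `eval_every` timesteps.
    if config["num_evals"] * config["eval_every"] != config["num_timesteps"]:
        raise ValueError("num_evals * eval_every != num_timesteps")

    # Each env must collect a whole number of unrolls per training update.
    if (config["batch_size"] * config["num_minibatches"]) % config["num_envs"] != 0:
        raise ValueError("batch_size * num_minibatches must be divisible by num_envs")

    # Environment steps consumed by one training update.
    return config["batch_size"] * config["num_minibatches"] * config["unroll_length"]


# Numeric fields of the current config.
EXAMPLE_CONFIG = {
    "num_timesteps": 10_000_000, "num_evals": 1000, "eval_every": 10_000,
    "episode_length": 1000, "num_envs": 512, "batch_size": 512,
    "num_minibatches": 32, "num_updates_per_batch": 2, "unroll_length": 5,
}

if __name__ == "__main__":
    print(check_config(EXAMPLE_CONFIG))  # 81920
```

Calling `check_config(config)` near the top of a training script makes a bad edit fail immediately instead of partway through a long job.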
Caveat: on the run.ai cluster with NVIDIA A40 GPUs, we can only use `num_envs = 512`.
Run the training with:

```
python server_run.py
```