CAS-CLab/MaxEntDP

 
 

Maximum Entropy RL with Diffusion Policy (MaxEntDP)

Overview

This repository provides the implementation of the MaxEntDP algorithm for the paper "Maximum Entropy Reinforcement Learning with Diffusion Policy".

Installation

To get started, you need to install the required dependencies.

conda create -n MaxEntDP python=3.9
conda activate MaxEntDP
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .

Getting Started

To reproduce the results in the paper, change to the examples/states directory and run the training script with the temperature (--config.temp) listed for each environment:

cd examples/states
XLA_PYTHON_CLIENT_MEM_FRACTION=.1 python3 train_score_matching_online.py --config configs/max_entropy_learner_config.py --env_name HalfCheetah-v3 --config.temp 0.2
XLA_PYTHON_CLIENT_MEM_FRACTION=.1 python3 train_score_matching_online.py --config configs/max_entropy_learner_config.py --env_name Humanoid-v3 --config.temp 0.02
XLA_PYTHON_CLIENT_MEM_FRACTION=.1 python3 train_score_matching_online.py --config configs/max_entropy_learner_config.py --env_name Ant-v3 --config.temp 0.05
XLA_PYTHON_CLIENT_MEM_FRACTION=.1 python3 train_score_matching_online.py --config configs/max_entropy_learner_config.py --env_name Walker2d-v3 --config.temp 0.01
XLA_PYTHON_CLIENT_MEM_FRACTION=.1 python3 train_score_matching_online.py --config configs/max_entropy_learner_config.py --env_name Hopper-v3 --config.temp 0.05
XLA_PYTHON_CLIENT_MEM_FRACTION=.1 python3 train_score_matching_online.py --config configs/max_entropy_learner_config.py --env_name Swimmer-v3 --config.temp 0.01
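
The six commands differ only in the environment name and the temperature. XLA_PYTHON_CLIENT_MEM_FRACTION is a JAX environment variable that limits the fraction of GPU memory preallocated by each process, so .1 leaves room to run several jobs on one device. As a convenience, the commands could be generated from a single table. The following is a minimal sketch and not part of the repository: the names TEMPS and build_command are hypothetical, and the values are copied from the commands above.

```python
import shlex

# Temperature used for each environment, copied from the commands above.
TEMPS = {
    "HalfCheetah-v3": 0.2,
    "Humanoid-v3": 0.02,
    "Ant-v3": 0.05,
    "Walker2d-v3": 0.01,
    "Hopper-v3": 0.05,
    "Swimmer-v3": 0.01,
}


def build_command(env_name, temp, mem_fraction=".1"):
    """Return the shell command line for one training run (hypothetical helper)."""
    args = [
        "python3", "train_score_matching_online.py",
        "--config", "configs/max_entropy_learner_config.py",
        "--env_name", env_name,
        "--config.temp", str(temp),
    ]
    # The environment variable prefix applies to this one command only.
    return f"XLA_PYTHON_CLIENT_MEM_FRACTION={mem_fraction} " + shlex.join(args)


if __name__ == "__main__":
    # Print one command per environment; run them from examples/states.
    for env_name, temp in TEMPS.items():
        print(build_command(env_name, temp))
```

The script only prints the command lines, so they can be reviewed or piped to a shell from examples/states.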

When running on multiple GPUs, the batch size (default 256) must be divisible by the number of devices.
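
The divisibility requirement can be checked before launching a run. The sketch below is illustrative only and does not reflect how the repository handles it: per_device_batch_size is a hypothetical helper, and in practice num_devices would typically come from jax.device_count().

```python
def per_device_batch_size(batch_size: int, num_devices: int) -> int:
    """Return the batch share per device, or raise if it does not split evenly.

    Hypothetical helper for illustration; not part of the repository.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by {num_devices} devices"
        )
    return batch_size // num_devices


if __name__ == "__main__":
    # The default batch size of 256 splits evenly across 1, 2, 4, or 8 devices.
    for n in (1, 2, 4, 8):
        print(n, per_device_batch_size(256, n))
```

With the default batch size of 256, a 3-device setup would fail this check, so either the batch size or the set of visible devices has to change.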

Important Files and Scripts

  • Main Training Script: Trains a diffusion model agent using MaxEntDP, with options for the environment and training scenario.

  • MaxEntDP Learner: The core implementation of the MaxEntDP algorithm, including methods for creating the learner, updating the critic and actor networks, and sampling actions. If you change the learner after installation, reinstall jaxrl5 locally by running the following from the root directory of the repository:

pip install ./

The code is built on top of the QSM implementation.
