OASIS: Conditional Distribution Shaping for Offline Safe Reinforcement Learning

[🌐 Website] [📜 arXiv] [🤗 HF Models]

Repo for "OASIS: Conditional Distribution Shaping for Offline Safe Reinforcement Learning" [NeurIPS'2024]

Email: {yihangya, zcen}[at]andrew.cmu.edu

Methods

Figure: OASIS, a data-centric approach for offline safe RL. Conditioned on a human safety preference, OASIS first curates an offline dataset with a conditional diffusion data generator and learned labeling models, then trains safe RL agents on this generated dataset. We provide example checkpoints for the models and curated datasets in our 🤗 Hugging Face repo.

Installation

This code is tested on Ubuntu 18.04. To install the packages, first create a Python environment with python==3.8, then run:

cd OSRL
pip install -e .
cd ../DSRL
pip install -e .
cd ..
pip install -r requirements.txt

Method

The proposed method contains 3 steps:

(1) Learning a set of data generator models: a state-sequence diffusion model, cost/reward labeling models, and an inverse dynamics model;

(2) Generating a dataset with the learned models, conditioned on the user's safety preference;

(3) Training an offline safe RL agent on this generated dataset.

If you want to skip training the generator and labeling models, you can go directly to step 3 and download the pre-trained models and generated datasets instead.
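The three steps above can be sketched as toy code. This is a hypothetical illustration only, not the actual OASIS implementation: all function names and the stub "models" below are made up, and the real system uses a conditional diffusion generator and neural labeling/inverse-dynamics networks.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the models learned in Step 1 (hypothetical).
def generate_states(cost_condition, horizon=8, state_dim=4):
    # Step 2a: sample a state sequence conditioned on a safety preference.
    return rng.normal(scale=1.0 + cost_condition, size=(horizon, state_dim))

def inverse_dynamics(s, s_next):
    # Step 2b: recover the action connecting consecutive states
    # (pretend dynamics: s' = s + a).
    return s_next - s

def label(s, a):
    # Step 2c: label each transition with a reward and a cost.
    reward = -np.linalg.norm(a)
    cost = float(np.linalg.norm(s) > 2.0)
    return reward, cost

# Step 2: curate a small dataset conditioned on a target cost level.
states = generate_states(cost_condition=0.2)
dataset = []
for s, s_next in zip(states[:-1], states[1:]):
    a = inverse_dynamics(s, s_next)
    r, c = label(s, a)
    dataset.append((s, a, r, c, s_next))

# Step 3 would train any offline safe RL agent (e.g. BCQ-Lag) on `dataset`.
print(len(dataset))  # 7 transitions from an 8-state sequence
```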

(Step 1) OASIS Training

To train an OASIS data generator, run:

cd OSRL/examples/train
python train_oasis.py

This trains an OASIS model for the OfflineBallCircle-v0 task using the tempting dataset. The learned models comprise a (state-sequence) data generator and an inverse dynamics model. The reward model for OfflineBallCircle-v0 can be trained by running:

cd OSRL/examples/train
python train_label.py --task OfflineBallCircle-v0 --learning_mode reward

You can set --task to any of OfflineBallCircle-v0, OfflineCarCircle-v0, OfflineDroneCircle-v0, OfflineBallRun-v0, OfflineCarRun-v0, and OfflineDroneRun-v0. By default, OfflineBallCircle-v0 is used.
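For intuition, a labeling model is just a supervised regressor from (state, action) pairs to scalar rewards or costs. Below is a minimal sketch with a linear model fit by least squares on toy data; the actual models trained by train_label.py are neural networks, and their architecture is not assumed here.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy transitions: states (dim 3) and actions (dim 2).
S = rng.normal(size=(500, 3))
A = rng.normal(size=(500, 2))
X = np.hstack([S, A])

# For this toy example, the ground-truth reward is linear in (s, a).
w_true = np.array([0.5, -1.0, 0.3, 2.0, -0.7])
r = X @ w_true + 0.01 * rng.normal(size=500)

# "Label model": least-squares fit of reward from (state, action).
w_hat, *_ = np.linalg.lstsq(X, r, rcond=None)

# The fitted model can now label generated transitions with rewards.
print(np.allclose(w_hat, w_true, atol=0.05))  # True
```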

Checkpoints for the learned models, as well as generated datasets, are available online. Download them by running:

cd OASIS
git clone https://huggingface.co/YYY-45/OASIS
mkdir -p dataset/from_tempting
cp -r OASIS/tempting/dataset/* dataset/from_tempting

(Step 2) Dataset Generation

To generate a dataset using OASIS, run:

cd Generation
python dataset_generation.py

This uses the pre-trained OASIS model "BallCircle.pt" from the "OASIS/models" folder, and the pre-trained cost/reward models "BC_cost.pt" and "BC_reward.pt" to label the dataset. The generated dataset is saved to the "dataset" folder, and the target cost limit is 20. To use different models, specify their paths when running dataset_generation.py, and make sure the model configs are aligned.
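The cost limit is the conditioning target: generated trajectories should keep the cumulative episode cost at or below 20. A quick sanity check of labeled per-step costs against that limit might look like this (toy cost labels; no assumption is made here about the actual dataset file format):

```python
import numpy as np

COST_LIMIT = 20.0

# Toy binary per-step cost labels for 3 episodes of length 50,
# with increasing violation rates (0.1, 0.4, 0.8).
rng = np.random.default_rng(2)
episode_costs = rng.binomial(1, [0.1, 0.4, 0.8], size=(50, 3)).sum(axis=0)

# Which episodes keep their cumulative cost within the target limit?
within_limit = episode_costs <= COST_LIMIT
print(episode_costs, within_limit.mean())
```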

(Step 3) RL Agent Training

Our method is compatible with general offline safe RL algorithms; in this paper, we train a BCQ-Lag agent on the generated dataset. To train a BCQ-Lag agent, run:

cd OSRL/examples/train
python train_bcql.py --task OfflineBallCircle-v0

This uses the dataset saved in the "dataset" folder to train a BCQ-Lag agent with a cost limit of 20. To train on your own dataset, change the data path with python train_bcql.py --new_data_path [your_path]. You can also use the provided datasets for the other tasks by setting the task name to one of OfflineBallCircle-v0, OfflineCarCircle-v0, OfflineDroneCircle-v0, OfflineBallRun-v0, OfflineCarRun-v0, and OfflineDroneRun-v0.
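BCQ-Lag combines BCQ-style policy constraints with a Lagrangian penalty on expected cost; the multiplier is typically updated by projected gradient ascent on the constraint violation. Here is a schematic sketch of that multiplier update only (illustrative, not the OSRL implementation; the learning rate is an arbitrary choice):

```python
# Schematic Lagrange-multiplier update used in Lagrangian safe RL
# methods such as BCQ-Lag (illustrative only).
COST_LIMIT = 20.0
LAMBDA_LR = 0.05  # hypothetical step size

def update_multiplier(lmbda, episode_cost):
    # Projected gradient ascent: lambda grows while cost exceeds the
    # limit, and shrinks (down to 0) once the policy is within budget.
    return max(0.0, lmbda + LAMBDA_LR * (episode_cost - COST_LIMIT))

lmbda = 0.0
for cost in [30.0, 28.0, 24.0, 18.0, 15.0]:  # toy per-iteration costs
    lmbda = update_multiplier(lmbda, cost)
print(round(lmbda, 3))
```

The penalty weight rises while the agent violates the budget and decays back toward zero once episode costs fall below the limit.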

GitHub Reference

Bibtex

If you find our code and paper helpful, please consider citing our paper:

@article{
    yao2024oasis,
    title={OASIS: Conditional Distribution Shaping for Offline Safe Reinforcement Learning},
    author={Yao, Yihang and Cen, Zhepeng and Ding, Wenhao and Lin, Haohong and Liu, Shiqi and Zhang, Tingnan and Yu, Wenhao and Zhao, Ding},
    journal={arXiv preprint arXiv:2407.14653},
    year={2024}
}
