This repository contains the code and configuration files for training a multimodal fine-tuned InstructPix2Pix model that predicts future robotic action frames. Conditioned on the current observation and a text instruction (e.g., "stack blocks", "beat the blocks with hammer"), the model generates 256×256 images. It reaches SSIM up to 0.98 and PSNR above 40 dB on synthetic RoboTwin tasks.
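For reference, PSNR for 8-bit images is 10·log10(255² / MSE). The sketch below computes it with NumPy; it is a generic illustration under the assumption of uint8 frames of equal shape, not the repository's eval.py. For scale, a uniform error of 2 gray levels on every pixel already gives about 42.1 dB.

```python
import numpy as np


def psnr(pred: np.ndarray, target: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two images of equal shape.

    Assumes 8-bit images (pixel range 0..255) unless max_val is changed.
    """
    if pred.shape != target.shape:
        raise ValueError("images must have the same shape")
    # Mean squared error over all pixels and channels, in float64 to avoid uint8 overflow.
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))


# Example: a random 256x256 RGB frame and a copy with every pixel off by 2.
rng = np.random.default_rng(0)
target = rng.integers(0, 254, size=(256, 256, 3), dtype=np.uint8)  # max 253, so +2 cannot overflow
pred = target + 2
print(round(psnr(pred, target), 2))
```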
git clone https://github.com/CAI991108/robotic-frame-prediction.git
cd robotic-frame-prediction
cd instruct-pix2pix
conda env create -f environment.yaml
conda activate ip2p
bash scripts/download_pretrained_sd.sh # Stable Diffusion v1.5
bash scripts/download_checkpoints.sh # InstructPix2Pix
- Follow RoboTwin's official guide to generate episodes for three tasks: block_hammer_beat, block_handover, and block_stack_easy.
- Place the generated data in ./RoboTwin_data.
- Use the provided scripts to preprocess the RoboTwin dataset in two steps:
# Step 1: Extract frames and map to instructions
python ./RoboTwin/test.py --root_dir <your_RoboTwin_data_dir> --output_jsonl instructpix2pix_dataset.jsonl
# Step 2: Convert to InstructPix2Pix-compatible format
python ./instruct-pix2pix/data/instructpix2pix/data_prepare.py --input_jsonl instructpix2pix_dataset.jsonl --output_dir <your_output_dir>
- Edit ./instruct-pix2pix/configs/train.yaml to set the dataset and checkpoint paths:
data:
  params:
    batch_size: 2   # Batch size per GPU
    num_workers: 2  # Number of data-loading workers
    train:
      params:
        path: ./data/instructpix2pix  # Preprocessed data (the --output_dir from Step 2)
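The two preprocessing steps and the dataset path can be sketched as follows. This is not the repository's test.py or data_prepare.py: the frame offset, the sample naming, and the assumption that the converted data follows the upstream InstructPix2Pix EditDataset layout (a seeds.json index plus, per sample, a prompt.json with an "edit" field and <seed>_0.jpg / <seed>_1.jpg image pairs) are all assumptions. Under that layout, the path key in train.yaml would point at out_dir.

```python
import json
import shutil
import tempfile
from pathlib import Path


def make_pairs(frame_paths, instruction, offset=1):
    """Hypothetical Step 1: pair frame t (current observation) with frame
    t + offset (future frame) under the episode's task instruction."""
    if offset < 1:
        raise ValueError("offset must be >= 1")
    return [(frame_paths[t], frame_paths[t + offset], instruction)
            for t in range(len(frame_paths) - offset)]


def write_editdataset(out_dir, pairs, name_prefix):
    """Hypothetical Step 2: write pairs in the assumed EditDataset layout:
        <out_dir>/seeds.json             [[name, [seed]], ...]
        <out_dir>/<name>/prompt.json     {"edit": instruction}
        <out_dir>/<name>/<seed>_0.jpg    current frame
        <out_dir>/<name>/<seed>_1.jpg    future frame
    Frames are assumed to already be JPEG files and are copied unchanged."""
    root = Path(out_dir)
    index = []
    for i, (before, after, instruction) in enumerate(pairs):
        name, seed = f"{name_prefix}_{i:05d}", 0  # one sample (seed 0) per pair
        sample_dir = root / name
        sample_dir.mkdir(parents=True, exist_ok=True)
        (sample_dir / "prompt.json").write_text(json.dumps({"edit": instruction}), encoding="utf-8")
        shutil.copyfile(before, sample_dir / f"{seed}_0.jpg")
        shutil.copyfile(after, sample_dir / f"{seed}_1.jpg")
        index.append([name, [seed]])
    (root / "seeds.json").write_text(json.dumps(index), encoding="utf-8")
    return index


# Usage with a fake 4-frame episode in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    frames = []
    for t in range(4):
        frame = Path(tmp, f"frame_{t:03d}.jpg")
        frame.write_bytes(b"\xff\xd8 placeholder")  # stands in for a real JPEG
        frames.append(str(frame))
    pairs = make_pairs(frames, "stack blocks", offset=2)
    out_dir = Path(tmp, "instructpix2pix")
    index = write_editdataset(out_dir, pairs, "block_stack_easy_ep0")
    sample_files = sorted(p.name for p in (out_dir / index[0][0]).iterdir())
    prompt = json.loads((out_dir / index[0][0] / "prompt.json").read_text(encoding="utf-8"))
```

A larger offset asks the model to predict further into the future; which offset the repository uses is not stated here.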
python ./instruct-pix2pix/main.py \
--name default \
--base configs/train.yaml \
--train \
--gpus 0,1 # Use 2 GPUs
- Batch size: 2 per GPU (effective 16 with gradient accumulation)
- Learning rate: 1e-4 (AdamW optimizer)
- Epochs: 100
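The effective batch size is the product of the per-GPU batch, the number of GPUs, and the gradient-accumulation steps. Assuming the effective 16 counts both GPUs used in the training command, that implies 4 accumulation steps (8 on a single GPU). In PyTorch Lightning this factor is set through the Trainer's accumulate_grad_batches option; where this repository sets it is not stated here. A small helper (hypothetical, for illustration):

```python
def grad_accumulation_steps(effective_batch: int, per_gpu_batch: int, num_gpus: int) -> int:
    """Steps needed so that per_gpu_batch * num_gpus * steps == effective_batch."""
    per_step = per_gpu_batch * num_gpus
    if effective_batch % per_step != 0:
        raise ValueError("effective batch must be a multiple of per_gpu_batch * num_gpus")
    return effective_batch // per_step


# Settings above: 2 per GPU, 2 GPUs, effective batch 16.
steps = grad_accumulation_steps(16, 2, 2)
print(steps)  # 4
```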
python ./instruct-pix2pix/eval.py --ckpt logs/train_default/checkpoints/last.ckpt
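The SSIM settings used by eval.py are not documented here. As a hedged illustration of what the score measures, the sketch below computes a simplified single-window (global) SSIM with the usual constants K1 = 0.01 and K2 = 0.03. Standard SSIM averages this quantity over local windows, so its values will differ; for reportable numbers, a windowed implementation such as scikit-image's skimage.metrics.structural_similarity is the usual choice.

```python
import numpy as np


def global_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 255.0) -> float:
    """Single-window SSIM over the whole image (no sliding window).

    Simplified sketch; standard SSIM averages this over local windows.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    c1 = (0.01 * data_range) ** 2  # stabilizes the luminance term
    c2 = (0.03 * data_range) ** 2  # stabilizes the contrast/structure term
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return float(num / den)


# Example: a random 256x256 RGB frame versus itself and a noisy copy.
rng = np.random.default_rng(0)
frame = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
noise = rng.integers(-20, 21, size=frame.shape)
noisy = np.clip(frame.astype(np.int16) + noise, 0, 255).astype(np.uint8)
s_same = global_ssim(frame, frame)
s_noisy = global_ssim(frame, noisy)
print(s_same, s_noisy)
```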
- GPUs: 2× NVIDIA RTX 2080 Ti (22GB VRAM each)
- RAM: 134 GB
- Dependency Conflicts: Use the exact package versions pinned in requirements.txt and environment.yaml.
- OOM Errors: Reduce the batch size, enable --half-precision, or set use_ema: false.
- Dataset Paths: Verify the paths in train.yaml and data_prepare.py.
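A quick way to check the dataset path is to confirm that the directory named by the path key in train.yaml has the expected files. The checker below is a hypothetical sketch that assumes the upstream InstructPix2Pix EditDataset layout (seeds.json listing [name, [seeds]] pairs, each sample folder holding prompt.json and <seed>_0.jpg / <seed>_1.jpg); adjust it if data_prepare.py writes something different.

```python
import json
import tempfile
from pathlib import Path


def check_editdataset_dir(path: str, max_samples: int = 50) -> list:
    """Return a list of missing files in a directory assumed to use the
    EditDataset layout. Only the first max_samples index entries are checked."""
    root = Path(path)
    seeds_file = root / "seeds.json"
    if not seeds_file.is_file():
        return [f"missing {seeds_file}"]
    problems = []
    entries = json.loads(seeds_file.read_text(encoding="utf-8"))
    for name, seeds in entries[:max_samples]:
        sample_dir = root / name
        if not (sample_dir / "prompt.json").is_file():
            problems.append(f"missing {sample_dir / 'prompt.json'}")
        for seed in seeds:
            for suffix in ("0", "1"):  # input frame and target frame
                img = sample_dir / f"{seed}_{suffix}.jpg"
                if not img.is_file():
                    problems.append(f"missing {img}")
    return problems


# Usage: an empty directory, then a sample that is missing its target frame.
with tempfile.TemporaryDirectory() as tmp:
    empty_report = check_editdataset_dir(tmp)
    sample = Path(tmp, "block_stack_easy_ep0_00000")
    sample.mkdir()
    (sample / "prompt.json").write_text(json.dumps({"edit": "stack blocks"}), encoding="utf-8")
    (sample / "0_0.jpg").write_bytes(b"")
    Path(tmp, "seeds.json").write_text(json.dumps([[sample.name, [0]]]), encoding="utf-8")
    report = check_editdataset_dir(tmp)
print(empty_report, report)
```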