Hello, I have a question about the reconstruction error of the VQ-VAE.
I loaded the model from https://huggingface.co/VQ-VLA/vq-vla-weight/tree/main/action_tokenizer_weight and tested it on libero_10_no_noops; the average L1 error is around 0.03-0.04.
Additionally, after visualizing several trajectories, I've noticed they deviate noticeably from the ground truth.
Are these error values within the expected range? More importantly, how much could this reconstruction error affect performance when the model is deployed on a physical robot?
Thanks!

P.S. The dots are the actual data points; the curves are B-spline interpolations (a minimal version of that interpolation step is sketched below).
Attachment: trajectory_batch3_sample1.html
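In case it helps, here is a condensed, standalone version of the interpolation step used for those curves. It mirrors the interpolate_trajectory_bspline helper in the full script further below; bspline_curve is just an illustrative name, and it assumes at least k + 1 points (the full script falls back to linear interpolation otherwise):

import numpy as np
from scipy.interpolate import splev, splprep

def bspline_curve(points_xyz, k=2, num_points=2000):
    """Fit a B-spline through a (T, 3) array of xyz points and resample it densely."""
    # splprep expects the coordinates as an array of shape (ndim, T)
    tck, _ = splprep(points_xyz.T, s=0, k=min(k, len(points_xyz) - 1))
    u_new = np.linspace(0.0, 1.0, num_points)
    return np.asarray(splev(u_new, tck)).T  # (num_points, 3)

# e.g. turn a 5-step xyz chunk into a smooth 2000-point curve for plotting:
# curve = bspline_curve(np.random.rand(5, 3))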
Here is the full script I use to produce the visualizations:
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import plotly.graph_objects as go
import torch
import torch.distributed as dist
import torch.nn.functional as F
from scipy.interpolate import splev, splprep
from torch.utils.data import DataLoader
sys.path.append(os.path.abspath("."))
from prismatic.action_vqvae import ActionVQVAELossWrapper
from prismatic.vla.datasets import RLDSActionBatchTransform, VqVAERLDSDataset

def init_distributed_mode_simple():
    """Initialize distributed mode for single GPU evaluation."""
    if not dist.is_initialized():
        os.environ['RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group(backend='nccl', init_method='env://', world_size=1, rank=0)

def interpolate_trajectory_bspline(traj, mask, k=2, num_points=2000):
    """Interpolate trajectory using B-spline."""
    B, T, D = traj.shape
    interp_traj = torch.zeros(B, num_points, D, device=traj.device)
    for b in range(B):
        valid_points = traj[b][mask[b].bool()]
        if len(valid_points) < k + 1:
            # Not enough points for spline, use linear interpolation
            t = np.linspace(0, 1, len(valid_points))
            t_new = np.linspace(0, 1, num_points)
            for d in range(D):
                interp_traj[b, :, d] = torch.tensor(
                    np.interp(t_new, t, valid_points[:, d].cpu().numpy()),
                    device=traj.device,
                    dtype=traj.dtype,
                )
        else:
            # Use B-spline
            points = valid_points.cpu().numpy().T
            try:
                tck, _ = splprep(points, s=0, k=min(k, len(valid_points) - 1))
                u_new = np.linspace(0, 1, num_points)
                interp_points = splev(u_new, tck)
                interp_traj[b] = torch.tensor(
                    np.array(interp_points).T,
                    device=traj.device,
                    dtype=traj.dtype,
                )
            except Exception as e:
                print(f"Spline interpolation failed for batch {b}: {e}")
                # Fallback to linear interpolation
                t = np.linspace(0, 1, len(valid_points))
                t_new = np.linspace(0, 1, num_points)
                for d in range(D):
                    interp_traj[b, :, d] = torch.tensor(
                        np.interp(t_new, t, valid_points[:, d].cpu().numpy()),
                        device=traj.device,
                        dtype=traj.dtype,
                    )
    return interp_traj, torch.ones(B, num_points, device=traj.device, dtype=torch.bool)

def visualize_trajectory(gt_traj, recon_traj, save_path, k=3, sample_idx=0):
    """Visualize GT vs reconstructed trajectory."""
    # Handle batch dimension properly
    if gt_traj.dim() == 3:  # If batch dimension exists
        gt_sample = gt_traj[sample_idx : sample_idx + 1]  # Keep batch dimension
        recon_sample = recon_traj[sample_idx : sample_idx + 1]
    else:
        gt_sample = gt_traj.unsqueeze(0)
        recon_sample = recon_traj.unsqueeze(0)

    B, T, D = gt_sample.shape
    mask = torch.ones(B, T, dtype=torch.bool, device=gt_sample.device)

    # Interpolate trajectories
    gt_interp, _ = interpolate_trajectory_bspline(gt_sample, mask, k=k, num_points=2000)
    recon_interp, _ = interpolate_trajectory_bspline(recon_sample, mask, k=k, num_points=2000)

    # Convert to numpy and get first sample (after interpolation)
    xyz_gt = gt_interp[0, :, :3].cpu().numpy()
    xyz_recon = recon_interp[0, :, :3].cpu().numpy()

    # Original points
    gt_orig = gt_sample[0, :, :3].cpu()
    recon_orig = recon_sample[0, :, :3].cpu()
    print(f"l1 distance: {F.l1_loss(gt_orig, recon_orig)}")
    gt_orig = gt_orig.numpy()
    recon_orig = recon_orig.numpy()

    traces = []
    # Ground Truth
    traces.append(
        go.Scatter3d(
            x=xyz_gt[:, 0],
            y=xyz_gt[:, 1],
            z=xyz_gt[:, 2],
            mode="lines",
            line=dict(color="blue", width=3),
            name="GT Trajectory",
        )
    )
    traces.append(
        go.Scatter3d(
            x=gt_orig[:, 0],
            y=gt_orig[:, 1],
            z=gt_orig[:, 2],
            mode="markers",
            marker=dict(color="blue", size=3),
            name="GT Points",
        )
    )
    # Reconstructed
    traces.append(
        go.Scatter3d(
            x=xyz_recon[:, 0],
            y=xyz_recon[:, 1],
            z=xyz_recon[:, 2],
            mode="lines",
            line=dict(color="orange", width=3),
            name="Reconstructed Trajectory",
        )
    )
    traces.append(
        go.Scatter3d(
            x=recon_orig[:, 0],
            y=recon_orig[:, 1],
            z=recon_orig[:, 2],
            mode="markers",
            marker=dict(color="orange", size=3),
            name="Reconstructed Points",
        )
    )
    # Start/End markers
    traces.extend(
        [
            go.Scatter3d(
                x=[xyz_gt[0, 0]],
                y=[xyz_gt[0, 1]],
                z=[xyz_gt[0, 2]],
                mode="markers",
                marker=dict(color="green", size=8),
                name="Start",
            ),
            go.Scatter3d(
                x=[xyz_gt[-1, 0]],
                y=[xyz_gt[-1, 1]],
                z=[xyz_gt[-1, 2]],
                mode="markers",
                marker=dict(color="red", size=8),
                name="End",
            ),
        ]
    )

    # Calculate ranges with equal aspect ratio
    all_points = np.vstack([xyz_gt, xyz_recon])
    xyz_min = all_points.min(axis=0)
    xyz_max = all_points.max(axis=0)

    # Find the maximum range across all dimensions
    ranges = xyz_max - xyz_min
    max_range = max(ranges)

    # Calculate centers
    centers = (xyz_max + xyz_min) / 2

    # Set equal ranges for all axes with 0.05 spacing/padding
    spacing = 0.05
    half_range = max_range / 2 + spacing
    x_range = [centers[0] - half_range, centers[0] + half_range]
    y_range = [centers[1] - half_range, centers[1] + half_range]
    z_range = [centers[2] - half_range, centers[2] + half_range]

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f"GT vs Reconstructed Trajectory (Sample {sample_idx})",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode="cube",  # Equal aspect ratio
            xaxis=dict(
                range=x_range,
                dtick=spacing,  # Set tick spacing
                gridwidth=1,
                showgrid=True,
            ),
            yaxis=dict(
                range=y_range,
                dtick=spacing,  # Set tick spacing
                gridwidth=1,
                showgrid=True,
            ),
            zaxis=dict(
                range=z_range,
                dtick=spacing,  # Set tick spacing
                gridwidth=1,
                showgrid=True,
            ),
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)  # Better viewing angle
            ),
        ),
        showlegend=True,
        width=1000,
        height=800,
    )
    fig.write_html(save_path)
    print(f"Saved visualization to {save_path}")

def build_model(args):
    """Build and load the VQ-VAE model."""
    model = ActionVQVAELossWrapper(
        args.vqvae_config_path,
        use_action_type_pe=args.use_action_type_pe,
        use_time_pe=args.use_time_pe,
        checkpoint_path=args.checkpoint_path,
        resume=True,
    )
    return model

def main():
    parser = argparse.ArgumentParser(description="Visualize VQ-VAE action reconstruction")
    parser.add_argument("--vqvae_config_path", required=True, type=str,
                        help="Path to VQ-VAE config file")
    parser.add_argument("--checkpoint_path", required=True, type=str,
                        help="Path to model checkpoint")
    parser.add_argument("--data_root_dir", required=True, type=str,
                        help="Root directory of dataset")
    parser.add_argument("--dataset_name", default="libero_10_no_noops", type=str,
                        help="Name of dataset")
    parser.add_argument("--output_dir", default="visualizations", type=str,
                        help="Output directory for visualizations")
    parser.add_argument("--num_samples", default=5, type=int,
                        help="Number of samples to visualize")
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Batch size for processing")
    parser.add_argument("--device", default="cuda", type=str,
                        help="Device to use (cuda/cpu)")
    parser.add_argument("--use_action_type_pe", action="store_true",
                        help="Use action type positional encoding")
    parser.add_argument("--use_time_pe", action="store_true",
                        help="Use time positional encoding")
    args = parser.parse_args()

    # Check device availability
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        args.device = "cpu"

    # Initialize distributed mode for dataset
    if args.device == "cuda":
        init_distributed_mode_simple()

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Build and load model
    print("Loading model...")
    model = build_model(args)
    model.to(args.device)
    model.eval()

    # Load dataset
    print("Loading dataset...")
    batch_transform = RLDSActionBatchTransform()
    dataset = VqVAERLDSDataset(
        args.data_root_dir,
        args.dataset_name,
        batch_transform,
        window_size=5,
        shuffle_buffer_size=10000,
        only_action=True,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=0,
        drop_last=False,
    )

    # Process samples
    total_rec_loss = 0
    total_l1_loss = 0
    total_xyz_l1_loss = 0
    num_batches = 0

    print(f"Processing {args.num_samples} samples...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= args.num_samples:
                break
            try:
                actions = batch["actions"].to(args.device)  # B, T, D

                # Forward pass through model
                commit_loss, rec_loss, total_loss = model(actions)

                # Get reconstructed actions
                latents = model.vqvae.encode(actions).latents
                B = latents.shape[0]
                state_rep_flat = latents.view(B, -1, latents.size(-1))
                state_rep_flat, _, _ = model.vqvae.vq_layer(state_rep_flat)
                state_vq = state_rep_flat.view(B, -1)
                reconstructed = model.vqvae.decode(state_vq)

                # Calculate losses
                l1_loss = F.l1_loss(reconstructed, actions)
                l2_loss = F.mse_loss(reconstructed, actions)
                xyz_l1_loss = F.l1_loss(reconstructed[..., :3], actions[..., :3])

                # Verify loss consistency
                if not torch.allclose(l2_loss, rec_loss, rtol=1e-5):
                    print(f"Warning: MSE loss mismatch - computed: {l2_loss.item():.6f}, model: {rec_loss.item():.6f}")

                # Accumulate losses
                total_rec_loss += rec_loss.item()
                total_l1_loss += l1_loss.item()
                total_xyz_l1_loss += xyz_l1_loss.item()
                num_batches += 1

                # Visualize first few samples in batch
                for i in range(min(3, B)):  # Visualize up to 3 samples per batch
                    save_path = os.path.join(args.output_dir, f"trajectory_batch{batch_idx}_sample{i}.html")
                    visualize_trajectory(actions, reconstructed, save_path, sample_idx=i)

                # Print metrics
                print(f"Batch {batch_idx}: Rec Loss = {rec_loss.item():.6f}, L1 Loss = {l1_loss.item():.6f}, xyz_l1_loss: {xyz_l1_loss.item():.6f}")
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                continue

    # Print average losses
    if num_batches > 0:
        avg_rec_loss = total_rec_loss / num_batches
        avg_l1_loss = total_l1_loss / num_batches
        avg_xyz_l1_loss = total_xyz_l1_loss / num_batches
        print(f"\nAverage L1 Loss: {avg_l1_loss:.6f}")
        print(f"Average L2 Loss: {avg_rec_loss:.6f}")
        print(f"Average XYZ L1 Loss: {avg_xyz_l1_loss:.6f}")
    else:
        print("No batches were successfully processed")

    # Clean up distributed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
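For completeness, this is roughly how I invoke the script; all paths below are placeholders for my local setup, not the repository's actual file names:

# Hypothetical driver equivalent to the command-line call I use (placeholder paths).
import sys

sys.argv = [
    "visualize_vqvae.py",
    "--vqvae_config_path", "path/to/vqvae_config.yaml",  # placeholder
    "--checkpoint_path", "path/to/action_tokenizer_weight.pt",  # placeholder
    "--data_root_dir", "path/to/rlds_datasets",  # placeholder
    "--dataset_name", "libero_10_no_noops",
    "--output_dir", "visualizations",
    "--num_samples", "5",
]
main()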