[Question] About reconstruction error #6

@luocfprime

Description

Hello, I have some questions about the reconstruction error of the VQ-VAE.

I loaded the model from https://huggingface.co/VQ-VLA/vq-vla-weight/tree/main/action_tokenizer_weight and tested it on libero_10_no_noops; the average L1 error is around 0.03-0.04.

Additionally, after visualizing several trajectories, I've noticed that the reconstructions deviate visibly from the ground truth.

Are these error values within the expected range? More importantly, how much might this reconstruction error affect performance when the model is deployed on a physical robot?

Thanks!

[Image: GT vs. reconstructed trajectory plot]

P.S. The dots are the actual data points; the curves are interpolated with a B-spline.

Attachment: trajectory_batch3_sample1.html

Here is the code snippet I use to visualize results:

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import torch
import torch.distributed as dist
import torch.nn.functional as F
from scipy.interpolate import splev, splprep
from torch.utils.data import DataLoader

sys.path.append(os.path.abspath("."))

from prismatic.action_vqvae import ActionVQVAELossWrapper
from prismatic.vla.datasets import RLDSActionBatchTransform, VqVAERLDSDataset


def init_distributed_mode_simple():
    """Initialize distributed mode for single GPU evaluation."""
    if not dist.is_initialized():
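        # Fake a single-process rendezvous so init_process_group works without torchrun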
        os.environ['RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group(backend='nccl', init_method='env://', world_size=1, rank=0)


def interpolate_trajectory_bspline(traj, mask, k=2, num_points=2000):
    """Interpolate trajectory using B-spline."""
    B, T, D = traj.shape
    interp_traj = torch.zeros(B, num_points, D, device=traj.device)

    for b in range(B):
        valid_points = traj[b][mask[b].bool()]
        if len(valid_points) < k + 1:
            # Not enough points for spline, use linear interpolation
            t = np.linspace(0, 1, len(valid_points))
            t_new = np.linspace(0, 1, num_points)
            for d in range(D):
                interp_traj[b, :, d] = torch.tensor(
                    np.interp(t_new, t, valid_points[:, d].cpu().numpy()),
                    device=traj.device,
                    dtype=traj.dtype
                )
        else:
            # Use B-spline
            points = valid_points.cpu().numpy().T
            try:
                tck, _ = splprep(points, s=0, k=min(k, len(valid_points)-1))
                u_new = np.linspace(0, 1, num_points)
                interp_points = splev(u_new, tck)
                interp_traj[b] = torch.tensor(
                    np.array(interp_points).T,
                    device=traj.device,
                    dtype=traj.dtype
                )
            except Exception as e:
                print(f"Spline interpolation failed for batch {b}: {e}")
                # Fallback to linear interpolation
                t = np.linspace(0, 1, len(valid_points))
                t_new = np.linspace(0, 1, num_points)
                for d in range(D):
                    interp_traj[b, :, d] = torch.tensor(
                        np.interp(t_new, t, valid_points[:, d].cpu().numpy()),
                        device=traj.device,
                        dtype=traj.dtype
                    )

    return interp_traj, torch.ones(B, num_points, device=traj.device, dtype=torch.bool)


def visualize_trajectory(gt_traj, recon_traj, save_path, k=3, sample_idx=0):
    """Visualize GT vs reconstructed trajectory."""
    # Handle batch dimension properly
    if gt_traj.dim() == 3:  # If batch dimension exists
        gt_sample = gt_traj[sample_idx : sample_idx + 1]  # Keep batch dimension
        recon_sample = recon_traj[sample_idx : sample_idx + 1]
    else:
        gt_sample = gt_traj.unsqueeze(0)
        recon_sample = recon_traj.unsqueeze(0)

    B, T, D = gt_sample.shape
    mask = torch.ones(B, T, dtype=torch.bool, device=gt_sample.device)

    # Interpolate trajectories
    gt_interp, _ = interpolate_trajectory_bspline(gt_sample, mask, k=k, num_points=2000)
    recon_interp, _ = interpolate_trajectory_bspline(recon_sample, mask, k=k, num_points=2000)

    # Convert to numpy and get first sample (after interpolation)
    xyz_gt = gt_interp[0, :, :3].cpu().numpy()
    xyz_recon = recon_interp[0, :, :3].cpu().numpy()

    # Original points
    gt_orig = gt_sample[0, :, :3].cpu()
    recon_orig = recon_sample[0, :, :3].cpu()

    print(f"l1 distance: {F.l1_loss(gt_orig, recon_orig)}")

    gt_orig = gt_orig.numpy()
    recon_orig = recon_orig.numpy()

    traces = []

    # Ground Truth
    traces.append(
        go.Scatter3d(
            x=xyz_gt[:, 0],
            y=xyz_gt[:, 1],
            z=xyz_gt[:, 2],
            mode="lines",
            line=dict(color="blue", width=3),
            name="GT Trajectory",
        )
    )
    traces.append(
        go.Scatter3d(
            x=gt_orig[:, 0],
            y=gt_orig[:, 1],
            z=gt_orig[:, 2],
            mode="markers",
            marker=dict(color="blue", size=3),
            name="GT Points",
        )
    )

    # Reconstructed
    traces.append(
        go.Scatter3d(
            x=xyz_recon[:, 0],
            y=xyz_recon[:, 1],
            z=xyz_recon[:, 2],
            mode="lines",
            line=dict(color="orange", width=3),
            name="Reconstructed Trajectory",
        )
    )
    traces.append(
        go.Scatter3d(
            x=recon_orig[:, 0],
            y=recon_orig[:, 1],
            z=recon_orig[:, 2],
            mode="markers",
            marker=dict(color="orange", size=3),
            name="Reconstructed Points",
        )
    )

    # Start/End markers
    traces.extend(
        [
            go.Scatter3d(
                x=[xyz_gt[0, 0]],
                y=[xyz_gt[0, 1]],
                z=[xyz_gt[0, 2]],
                mode="markers",
                marker=dict(color="green", size=8),
                name="Start",
            ),
            go.Scatter3d(
                x=[xyz_gt[-1, 0]],
                y=[xyz_gt[-1, 1]],
                z=[xyz_gt[-1, 2]],
                mode="markers",
                marker=dict(color="red", size=8),
                name="End",
            ),
        ]
    )

    # Calculate ranges with equal aspect ratio
    all_points = np.vstack([xyz_gt, xyz_recon])
    xyz_min = all_points.min(axis=0)
    xyz_max = all_points.max(axis=0)

    # Find the maximum range across all dimensions
    ranges = xyz_max - xyz_min
    max_range = max(ranges)

    # Calculate centers
    centers = (xyz_max + xyz_min) / 2

    # Set equal ranges for all axes with 0.05 spacing/padding
    spacing = 0.05
    half_range = max_range / 2 + spacing

    x_range = [centers[0] - half_range, centers[0] + half_range]
    y_range = [centers[1] - half_range, centers[1] + half_range]
    z_range = [centers[2] - half_range, centers[2] + half_range]

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f"GT vs Reconstructed Trajectory (Sample {sample_idx})",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode="cube",  # Equal aspect ratio
            xaxis=dict(
                range=x_range,
                dtick=spacing,  # Set tick spacing
                gridwidth=1,
                showgrid=True,
            ),
            yaxis=dict(
                range=y_range,
                dtick=spacing,  # Set tick spacing
                gridwidth=1,
                showgrid=True,
            ),
            zaxis=dict(
                range=z_range,
                dtick=spacing,  # Set tick spacing
                gridwidth=1,
                showgrid=True,
            ),
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)  # Better viewing angle
            ),
        ),
        showlegend=True,
        width=1000,
        height=800,
    )

    fig.write_html(save_path)
    print(f"Saved visualization to {save_path}")


def build_model(args):
    """Build and load the VQ-VAE model."""
    model = ActionVQVAELossWrapper(
        args.vqvae_config_path,
        use_action_type_pe=args.use_action_type_pe,
        use_time_pe=args.use_time_pe,
        checkpoint_path=args.checkpoint_path,
        resume=True,
    )
    return model


def main():
    parser = argparse.ArgumentParser(description="Visualize VQ-VAE action reconstruction")
    parser.add_argument("--vqvae_config_path", required=True, type=str,
                        help="Path to VQ-VAE config file")
    parser.add_argument("--checkpoint_path", required=True, type=str,
                        help="Path to model checkpoint")
    parser.add_argument("--data_root_dir", required=True, type=str,
                        help="Root directory of dataset")
    parser.add_argument("--dataset_name", default="libero_10_no_noops", type=str,
                        help="Name of dataset")
    parser.add_argument("--output_dir", default="visualizations", type=str,
                        help="Output directory for visualizations")
    parser.add_argument("--num_samples", default=5, type=int,
                        help="Number of samples to visualize")
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Batch size for processing")
    parser.add_argument("--device", default="cuda", type=str,
                        help="Device to use (cuda/cpu)")
    parser.add_argument("--use_action_type_pe", action="store_true",
                        help="Use action type positional encoding")
    parser.add_argument("--use_time_pe", action="store_true",
                        help="Use time positional encoding")
    args = parser.parse_args()

    # Check device availability
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        args.device = "cpu"

    # Initialize distributed mode for dataset
    if args.device == "cuda":
        init_distributed_mode_simple()

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Build and load model
    print("Loading model...")
    model = build_model(args)
    model.to(args.device)
    model.eval()

    # Load dataset
    print("Loading dataset...")
    batch_transform = RLDSActionBatchTransform()
    dataset = VqVAERLDSDataset(
        args.data_root_dir,
        args.dataset_name,
        batch_transform,
        window_size=5,
        shuffle_buffer_size=10000,
        only_action=True
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=0,
        drop_last=False
    )

    # Process samples
    total_rec_loss = 0
    total_l1_loss = 0
    total_xyz_l1_loss = 0
    num_batches = 0

    print(f"Processing {args.num_samples} samples...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= args.num_samples:
                break

            try:
                actions = batch["actions"].to(args.device)  # B, T, D

                # Forward pass through model
                commit_loss, rec_loss, total_loss = model(actions)

                # Get reconstructed actions
                latents = model.vqvae.encode(actions).latents
                B = latents.shape[0]
                state_rep_flat = latents.view(B, -1, latents.size(-1))
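                # Quantize the flattened latents with the VQ codebook (the other return values are unused here)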
                state_rep_flat, _, _ = model.vqvae.vq_layer(state_rep_flat)
                state_vq = state_rep_flat.view(B, -1)
                reconstructed = model.vqvae.decode(state_vq)

                # Calculate losses
                l1_loss = F.l1_loss(reconstructed, actions)
                l2_loss = F.mse_loss(reconstructed, actions)

                xyz_l1_loss = F.l1_loss(reconstructed[..., :3], actions[..., :3])

                # Verify loss consistency
                if not torch.allclose(l2_loss, rec_loss, rtol=1e-5):
                    print(f"Warning: MSE loss mismatch - computed: {l2_loss.item():.6f}, model: {rec_loss.item():.6f}")

                # Accumulate losses
                total_rec_loss += rec_loss.item()
                total_l1_loss += l1_loss.item()
                total_xyz_l1_loss += xyz_l1_loss.item()
                num_batches += 1

                # Visualize first few samples in batch
                for i in range(min(3, B)):  # Visualize up to 3 samples per batch
                    save_path = os.path.join(args.output_dir, f"trajectory_batch{batch_idx}_sample{i}.html")
                    visualize_trajectory(actions, reconstructed, save_path, sample_idx=i)

                # Print metrics
                print(f"Batch {batch_idx}: Rec Loss = {rec_loss.item():.6f}, L1 Loss = {l1_loss.item():.6f}, xyz_l1_loss: {xyz_l1_loss.item():.6f}")

            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                continue

    # Print average losses
    if num_batches > 0:
        avg_rec_loss = total_rec_loss / num_batches
        avg_l1_loss = total_l1_loss / num_batches
        avg_xyz_l1_loss = total_xyz_l1_loss / num_batches
        print(f"\nAverage L1 Loss: {avg_l1_loss:.6f}")
        print(f"Average L2 Loss: {avg_rec_loss:.6f}")
        print(f"Average XYZ L1 Loss: {avg_xyz_l1_loss:.6f}")
    else:
        print("No batches were successfully processed")

    # Clean up distributed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
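
For reference, a hypothetical invocation (the script name and paths below are placeholders for my local setup; all flags match the argparse definitions above):

python visualize_vqvae_recon.py \
    --vqvae_config_path configs/action_vqvae.yaml \
    --checkpoint_path action_tokenizer_weight/checkpoint.pt \
    --data_root_dir /path/to/rlds/datasets \
    --dataset_name libero_10_no_noops \
    --output_dir visualizations \
    --num_samples 5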
