diff --git a/tutorial/tutorial_visualization.ipynb b/tutorial/tutorial_visualization.ipynb index 47ce79d..95eac56 100644 --- a/tutorial/tutorial_visualization.ipynb +++ b/tutorial/tutorial_visualization.ipynb @@ -190,6 +190,103 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize trajectories in the camera frame\n", + "Script to visualize the traj in te camera frame, supports the cases when the trajectory overlaps over 2 cameras e.g. front and left." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# cast to the camera \n", + "def get_pixels(scene, agent, cam='front',human=False):\n", + " frame_idx = scene.scene_metadata.num_history_frames - 1\n", + " if cam == 'front':\n", + " camera = scene.frames[frame_idx].cameras.cam_f0\n", + " elif cam == 'left':\n", + " camera = scene.frames[frame_idx].cameras.cam_l0\n", + " else:\n", + " camera = scene.frames[frame_idx].cameras.cam_r0\n", + " if human:\n", + " agent_trajectory = scene.get_future_trajectory()\n", + " else:\n", + " agent_trajectory = agent.compute_trajectory(scene.get_agent_input())\n", + " trajectory = agent_trajectory.poses\n", + " # transformation matrices \n", + " R_c2l = camera.sensor2lidar_rotation # camera-to-lidar (3x3)\n", + " T_c2l = camera.sensor2lidar_translation # camera-to-lidar (3,)\n", + " intrinsics = camera.intrinsics\n", + " # add z \n", + " trajectory_xyz = np.hstack([trajectory[:, :2], np.zeros((trajectory.shape[0], 1))]) # (N, 3)\n", + " \n", + " # lidar_to_camera = inverse of camera_to_lidar\n", + " R_l2c = R_c2l.T # inverse of rotation\n", + " T_l2c = -R_l2c @ T_c2l # inverse of translation\n", + " \n", + " # apply to points\n", + " trajectory_cam = (R_l2c @ trajectory_xyz.T).T + T_l2c # shape: (N, 3)\n", + " in_front = trajectory_cam[:, 2] > 0\n", + " points_cam = trajectory_cam[in_front]\n", + " \n", + " # project to image plane using intrinsics\n", + " K = intrinsics # 3x3\n", + " projected = (K @ points_cam.T).T # shape: (N, 3)\n", + " \n", + " # pixel coordinates\n", + " pixels = projected[:, :2] / projected[:, 2:3] # divide x and y by z\n", + " print(pixels)\n", + " image = camera.image # shape (H, W, 3), \n", + " img_h, img_w = image.shape[:2]\n", + " valid = (\n", + " (pixels[:, 0] >= 0) & (pixels[:, 0] < img_w) &\n", + " (pixels[:, 1] >= 0) & (pixels[:, 1] < img_h)\n", + " )\n", + " pixels_clipped = pixels[valid]\n", + " return pixels_clipped\n", + " \n", + "def plot_traj_img(scene, pixels_clipped, cam='front',ground_truth_clipped=None):\n", + " frame_idx = scene.scene_metadata.num_history_frames - 1\n", + " if cam == 'front':\n", + " camera = scene.frames[frame_idx].cameras.cam_f0\n", + " elif cam == 'left':\n", + " camera = scene.frames[frame_idx].cameras.cam_l0\n", + " else:\n", + " camera = scene.frames[frame_idx].cameras.cam_r0\n", + " image = camera.image\n", + " plt.figure(figsize=(10, 6))\n", + " plt.imshow(image, alpha=0.6)\n", + " plt.scatter(pixels_clipped[:, 0], pixels_clipped[:, 1], c='red', s=40, marker='o')\n", + " plt.plot(pixels_clipped[:,0],pixels_clipped[:,1], color='red', linewidth=2, marker='o', markersize=5)\n", + " #gt \n", + " if ground_truth_clipped is not None:\n", + " plt.scatter(ground_truth_clipped[:, 0], ground_truth_clipped[:, 1], c='green', s=40, marker='o')\n", + " plt.plot(ground_truth_clipped[:,0],ground_truth_clipped[:,1], color='green', linewidth=2, marker='o', markersize=5)\n", + " \n", + " plt.title(\"Projected Waypoints onto Camera Image\")\n", + " plt.axis(\"off\")\n", + " plt.tight_layout()\n", + " plt.show()\n", + " \n", + "# example use\n", + "# pixels_clipped = get_pixels(scene, agent)\n", + "# gt = get_pixels(scene, agent,cam='front',human=True)\n", + "# plot_traj_img(scene, pixels_clipped, cam='front',ground_truth_clipped=gt)\n", + "\n", + "# pixels_clipped = get_pixels(scene, agent,'left')\n", + "# gt = get_pixels(scene, agent,cam='left',human=True)\n", + "\n", + "# plot_traj_img(scene, pixels_clipped, cam='left',ground_truth_clipped=gt)\n", + "\n", + "# print(pixels_clipped)\n", + "# Plot\n" + ] + }, { "cell_type": "markdown", "metadata": {},