Skip to content

Commit e2a0915

Browse files
authored
bugfix: handle multiple camera types in a single batch (#3415)
1 parent d67b281 commit e2a0915

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

nerfstudio/cameras/cameras.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -778,15 +778,15 @@ def _compute_rays_for_vr180(
778778

779779
return vr180_origins, directions_stack
780780

781-
for cam in cam_types:
782-
if CameraType.PERSPECTIVE.value in cam_types:
781+
for cam_type in cam_types:
782+
if CameraType.PERSPECTIVE.value == cam_type:
783783
mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
784784
mask = torch.stack([mask, mask, mask], dim=0)
785785
directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
786786
directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
787787
directions_stack[..., 2][mask] = -1.0
788788

789-
elif CameraType.FISHEYE.value in cam_types:
789+
elif CameraType.FISHEYE.value == cam_type:
790790
mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
791791
mask = torch.stack([mask, mask, mask], dim=0)
792792

@@ -803,7 +803,7 @@ def _compute_rays_for_vr180(
803803
).float()
804804
directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()
805805

806-
elif CameraType.EQUIRECTANGULAR.value in cam_types:
806+
elif CameraType.EQUIRECTANGULAR.value == cam_type:
807807
mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
808808
mask = torch.stack([mask, mask, mask], dim=0)
809809

@@ -816,22 +816,22 @@ def _compute_rays_for_vr180(
816816
directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
817817
directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()
818818

819-
elif CameraType.OMNIDIRECTIONALSTEREO_L.value in cam_types:
819+
elif CameraType.OMNIDIRECTIONALSTEREO_L.value == cam_type:
820820
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("left")
821821
# assign final camera origins
822822
c2w[..., :3, 3] = ods_origins_circle
823823

824-
elif CameraType.OMNIDIRECTIONALSTEREO_R.value in cam_types:
824+
elif CameraType.OMNIDIRECTIONALSTEREO_R.value == cam_type:
825825
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("right")
826826
# assign final camera origins
827827
c2w[..., :3, 3] = ods_origins_circle
828828

829-
elif CameraType.VR180_L.value in cam_types:
829+
elif CameraType.VR180_L.value == cam_type:
830830
vr180_origins, directions_stack = _compute_rays_for_vr180("left")
831831
# assign final camera origins
832832
c2w[..., :3, 3] = vr180_origins
833833

834-
elif CameraType.VR180_R.value in cam_types:
834+
elif CameraType.VR180_R.value == cam_type:
835835
vr180_origins, directions_stack = _compute_rays_for_vr180("right")
836836
# assign final camera origins
837837
c2w[..., :3, 3] = vr180_origins
@@ -880,7 +880,7 @@ def _compute_rays_for_vr180(
880880
directions_stack[coord_mask] = camera_utils.fisheye624_unproject(masked_coords, camera_params)
881881

882882
else:
883-
raise ValueError(f"Camera type {cam} not supported.")
883+
raise ValueError(f"Camera type {cam_type} not supported.")
884884

885885
assert directions_stack.shape == (3,) + num_rays_shape + (3,)
886886

0 commit comments

Comments
 (0)