Skip to content

Commit e16135e

Browse files
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio into main
2 parents be44674 + e7b7dc9 commit e16135e

File tree

6 files changed

+39
-22
lines changed

6 files changed

+39
-22
lines changed

nerfstudio/cameras/cameras.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -778,15 +778,15 @@ def _compute_rays_for_vr180(
778778

779779
return vr180_origins, directions_stack
780780

781-
for cam in cam_types:
782-
if CameraType.PERSPECTIVE.value in cam_types:
781+
for cam_type in cam_types:
782+
if CameraType.PERSPECTIVE.value == cam_type:
783783
mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
784784
mask = torch.stack([mask, mask, mask], dim=0)
785785
directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
786786
directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
787787
directions_stack[..., 2][mask] = -1.0
788788

789-
elif CameraType.FISHEYE.value in cam_types:
789+
elif CameraType.FISHEYE.value == cam_type:
790790
mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
791791
mask = torch.stack([mask, mask, mask], dim=0)
792792

@@ -803,7 +803,7 @@ def _compute_rays_for_vr180(
803803
).float()
804804
directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()
805805

806-
elif CameraType.EQUIRECTANGULAR.value in cam_types:
806+
elif CameraType.EQUIRECTANGULAR.value == cam_type:
807807
mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
808808
mask = torch.stack([mask, mask, mask], dim=0)
809809

@@ -816,22 +816,22 @@ def _compute_rays_for_vr180(
816816
directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
817817
directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()
818818

819-
elif CameraType.OMNIDIRECTIONALSTEREO_L.value in cam_types:
819+
elif CameraType.OMNIDIRECTIONALSTEREO_L.value == cam_type:
820820
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("left")
821821
# assign final camera origins
822822
c2w[..., :3, 3] = ods_origins_circle
823823

824-
elif CameraType.OMNIDIRECTIONALSTEREO_R.value in cam_types:
824+
elif CameraType.OMNIDIRECTIONALSTEREO_R.value == cam_type:
825825
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("right")
826826
# assign final camera origins
827827
c2w[..., :3, 3] = ods_origins_circle
828828

829-
elif CameraType.VR180_L.value in cam_types:
829+
elif CameraType.VR180_L.value == cam_type:
830830
vr180_origins, directions_stack = _compute_rays_for_vr180("left")
831831
# assign final camera origins
832832
c2w[..., :3, 3] = vr180_origins
833833

834-
elif CameraType.VR180_R.value in cam_types:
834+
elif CameraType.VR180_R.value == cam_type:
835835
vr180_origins, directions_stack = _compute_rays_for_vr180("right")
836836
# assign final camera origins
837837
c2w[..., :3, 3] = vr180_origins
@@ -880,7 +880,7 @@ def _compute_rays_for_vr180(
880880
directions_stack[coord_mask] = camera_utils.fisheye624_unproject(masked_coords, camera_params)
881881

882882
else:
883-
raise ValueError(f"Camera type {cam} not supported.")
883+
raise ValueError(f"Camera type {cam_type} not supported.")
884884

885885
assert directions_stack.shape == (3,) + num_rays_shape + (3,)
886886

nerfstudio/engine/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig):
8686
"""Optionally log gradients during training"""
8787
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
8888
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
89+
start_paused: bool = False
90+
"""Whether to start the training in a paused state."""
8991

9092

9193
class Trainer:
@@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
121123
self.device += f":{local_rank}"
122124
self.mixed_precision: bool = self.config.mixed_precision
123125
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
124-
self.training_state: Literal["training", "paused", "completed"] = "training"
126+
self.training_state: Literal["training", "paused", "completed"] = (
127+
"paused" if self.config.start_paused else "training"
128+
)
125129
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
126130
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)
127131

@@ -361,7 +365,7 @@ def _init_viewer_state(self) -> None:
361365
assert self.viewer_state and self.pipeline.datamanager.train_dataset
362366
self.viewer_state.init_scene(
363367
train_dataset=self.pipeline.datamanager.train_dataset,
364-
train_state="training",
368+
train_state=self.training_state,
365369
eval_dataset=self.pipeline.datamanager.eval_dataset,
366370
)
367371

nerfstudio/models/base_model.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -214,11 +214,7 @@ def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "r
214214
RGBA image.
215215
"""
216216
accumulation_name = output_name.replace("rgb", "accumulation")
217-
if (
218-
not hasattr(self, "renderer_rgb")
219-
or not hasattr(self.renderer_rgb, "background_color")
220-
or accumulation_name not in outputs
221-
):
217+
if accumulation_name not in outputs:
222218
raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}")
223219
rgb = outputs[output_name]
224220
if self.renderer_rgb.background_color == "random": # type: ignore

nerfstudio/scripts/render.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -197,6 +197,9 @@ def _render_trajectory_video(
197197
outputs = pipeline.model.get_outputs_for_camera(
198198
cameras[camera_idx : camera_idx + 1], obb_box=obb_box
199199
)
200+
if rendered_output_names is not None and "rgba" in rendered_output_names:
201+
rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
202+
outputs["rgba"] = rgba
200203

201204
render_image = []
202205
for rendered_output_name in rendered_output_names:
@@ -221,6 +224,8 @@ def _render_trajectory_video(
221224
.cpu()
222225
.numpy()
223226
)
227+
elif rendered_output_name == "rgba":
228+
output_image = output_image.detach().cpu().numpy()
224229
else:
225230
output_image = (
226231
colormaps.apply_colormap(
@@ -790,6 +795,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
790795
for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))):
791796
with torch.no_grad():
792797
outputs = pipeline.model.get_outputs_for_camera(camera)
798+
if self.rendered_output_names is not None and "rgba" in self.rendered_output_names:
799+
rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
800+
outputs["rgba"] = rgba
793801

794802
gt_batch = batch.copy()
795803
gt_batch["rgb"] = gt_batch.pop("image")
@@ -841,11 +849,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
841849
output_image = gt_batch[output_name]
842850
else:
843851
output_image = outputs[output_name]
844-
del output_name
845852

846853
# Map to color spaces / numpy
847854
if is_raw:
848855
output_image = output_image.cpu().numpy()
856+
elif output_name == "rgba":
857+
output_image = output_image.detach().cpu().numpy()
849858
elif is_depth:
850859
output_image = (
851860
colormaps.apply_depth_colormap(

nerfstudio/viewer/viewer.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,10 @@ def __init__(
103103
self.output_type_changed = True
104104
self.output_split_type_changed = True
105105
self.step = 0
106-
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
107-
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
106+
self.train_btn_state: Literal["training", "paused", "completed"] = (
107+
"training" if self.trainer is None else self.trainer.training_state
108+
)
109+
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
108110
self.last_move_time = 0
109111
# track the camera index that last being clicked
110112
self.current_camera_idx = 0
@@ -174,7 +176,11 @@ def __init__(
174176
)
175177
self.resume_train.on_click(lambda _: self.toggle_pause_button())
176178
self.resume_train.on_click(lambda han: self._toggle_training_state(han))
177-
self.resume_train.visible = False
179+
if self.train_btn_state == "training":
180+
self.resume_train.visible = False
181+
else:
182+
self.pause_train.visible = False
183+
178184
# Add buttons to toggle training image visibility
179185
self.hide_images = self.viser_server.gui.add_button(
180186
label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None

nerfstudio/viewer_legacy/server/viewer_state.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -116,8 +116,10 @@ def __init__(
116116
self.output_type_changed = True
117117
self.output_split_type_changed = True
118118
self.step = 0
119-
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
120-
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
119+
self.train_btn_state: Literal["training", "paused", "completed"] = (
120+
"training" if self.trainer is None else self.trainer.training_state
121+
)
122+
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
121123

122124
self.camera_message = None
123125

0 commit comments

Comments (0)