Skip to content

Commit d67b281

Browse files
authored
RGBA renderer for splatfacto (#3307)
* Update base_model.py * Update render.py * Update render.py * fixing error * Update render.py * Update render.py
1 parent f86dbe6 commit d67b281

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

nerfstudio/models/base_model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,7 @@ def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "r
214214
RGBA image.
215215
"""
216216
accumulation_name = output_name.replace("rgb", "accumulation")
217-
if (
218-
not hasattr(self, "renderer_rgb")
219-
or not hasattr(self.renderer_rgb, "background_color")
220-
or accumulation_name not in outputs
221-
):
217+
if accumulation_name not in outputs:
222218
raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}")
223219
rgb = outputs[output_name]
224220
if self.renderer_rgb.background_color == "random": # type: ignore

nerfstudio/scripts/render.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ def _render_trajectory_video(
197197
outputs = pipeline.model.get_outputs_for_camera(
198198
cameras[camera_idx : camera_idx + 1], obb_box=obb_box
199199
)
200+
if rendered_output_names is not None and "rgba" in rendered_output_names:
201+
rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
202+
outputs["rgba"] = rgba
200203

201204
render_image = []
202205
for rendered_output_name in rendered_output_names:
@@ -221,6 +224,8 @@ def _render_trajectory_video(
221224
.cpu()
222225
.numpy()
223226
)
227+
elif rendered_output_name == "rgba":
228+
output_image = output_image.detach().cpu().numpy()
224229
else:
225230
output_image = (
226231
colormaps.apply_colormap(
@@ -790,6 +795,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
790795
for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))):
791796
with torch.no_grad():
792797
outputs = pipeline.model.get_outputs_for_camera(camera)
798+
if self.rendered_output_names is not None and "rgba" in self.rendered_output_names:
799+
rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
800+
outputs["rgba"] = rgba
793801

794802
gt_batch = batch.copy()
795803
gt_batch["rgb"] = gt_batch.pop("image")
@@ -841,11 +849,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
841849
output_image = gt_batch[output_name]
842850
else:
843851
output_image = outputs[output_name]
844-
del output_name
845852

846853
# Map to color spaces / numpy
847854
if is_raw:
848855
output_image = output_image.cpu().numpy()
856+
elif output_name == "rgba":
857+
output_image = output_image.detach().cpu().numpy()
849858
elif is_depth:
850859
output_image = (
851860
colormaps.apply_depth_colormap(

0 commit comments

Comments
 (0)