Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 5b4d4f6

Browse files
authored
Merge pull request #133 from Sarasra/main
update robustnerf into a trainable state with a robustnerf config.
2 parents 31d857b + 00691de commit 5b4d4f6

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

configs/360_robustnerf.gin

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
Config.dataset_loader = 'llff'
2+
Config.near = 0.2
3+
Config.far = 1e6
4+
Config.factor = 4
5+
6+
Config.patch_size = 16
7+
Config.data_loss_type = 'robustnerf'
8+
Config.robustnerf_inlier_quantile = 0.8
9+
Config.enable_robustnerf_loss = True
10+
11+
Model.raydist_fn = @jnp.reciprocal
12+
Model.opaque_background = True
13+
14+
PropMLP.warp_fn = @coord.contract
15+
PropMLP.net_depth = 4
16+
PropMLP.net_width = 256
17+
PropMLP.disable_density_normals = True
18+
PropMLP.disable_rgb = True
19+
20+
NerfMLP.warp_fn = @coord.contract
21+
NerfMLP.net_depth = 8
22+
NerfMLP.net_width = 1024
23+
NerfMLP.disable_density_normals = True
24+

internal/configs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ class Config:
9292
orientation_loss_mult: float = 0.0 # Multiplier on the orientation loss.
9393
orientation_coarse_loss_mult: float = 0.0 # Coarser orientation loss weights.
9494
# RobustNerf loss hyperparameters
95-
robustnerf_inlier_quantile = 0.5
96-
enable_robutnerf_loss = False
97-
robustnerf_inner_patch_size = 8
98-
robustnerf_smoothed_filter_size = 3
99-
robustnerf_smoothed_inlier_quantile = 0.5
100-
robustnerf_inner_patch_inlier_quantile = 0.5
95+
robustnerf_inlier_quantile: float = 0.5
96+
enable_robustnerf_loss: bool = False
97+
robustnerf_inner_patch_size: int = 8
98+
robustnerf_smoothed_filter_size: int = 3
99+
robustnerf_smoothed_inlier_quantile: float = 0.5
100+
robustnerf_inner_patch_inlier_quantile: float = 0.5
101101
# What that loss is imposed on, options are 'normals' or 'normals_pred'.
102102
orientation_loss_target: str = 'normals_pred'
103103
predicted_normal_loss_mult: float = 0.0 # Mult. on the predicted normal loss.

internal/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def compute_data_loss(batch, renderings, rays, loss_threshold, config):
104104
elif config.data_loss_type == 'robustnerf':
105105
mask, robust_stats = robustnerf.robustnerf_mask(resid_sq, loss_threshold,
106106
config)
107-
data_loss = data_loss * mask
107+
data_loss = resid_sq * mask
108108
stats.update(robust_stats)
109109
else:
110110
assert False

train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def main(unused_argv):
125125
train_frac,
126126
loss_threshold,
127127
)
128-
loss_threshold = jnp.mean(stats['loss_threshold'])
128+
if config.enable_robustnerf_loss:
129+
loss_threshold = jnp.mean(stats['loss_threshold'])
129130

130131
if step % config.gc_every == 0:
131132
gc.collect() # Disable automatic garbage collection for efficiency.

0 commit comments

Comments
 (0)