Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 31d857b

Browse files
authored
Merge pull request #128 from Sarasra/main
Add robustnerf mask to mipnerf360
2 parents 5d4c828 + 025d7d7 commit 31d857b

File tree

4 files changed

+141
-4
lines changed

4 files changed

+141
-4
lines changed

internal/configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ class Config:
9191
interlevel_loss_mult: float = 1.0 # Mult. for the loss on the proposal MLP.
9292
orientation_loss_mult: float = 0.0 # Multiplier on the orientation loss.
9393
orientation_coarse_loss_mult: float = 0.0 # Coarser orientation loss weights.
94+
# RobustNeRF loss hyperparameters (read by robustnerf.robustnerf_mask).
# NOTE(review): annotated to match the neighbouring fields; if Config is a
# dataclass, unannotated names would not become fields -- confirm.
# Quantile of the per-pixel error that is reported as the next loss threshold.
robustnerf_inlier_quantile: float = 0.5
# Whether the robust mask is computed at all; when False the mask is all ones.
# (Previously misspelled `enable_robutnerf_loss`, which robustnerf_mask never
# reads.)
enable_robustnerf_loss: bool = False
# Side length of the centred inner patch; must not exceed patch_size.
robustnerf_inner_patch_size: int = 8
# Side length f of the fxf box filter used to smooth the per-pixel inlier map.
robustnerf_smoothed_filter_size: int = 3
# A pixel counts as having inlier neighbours when the smoothed inlier map
# exceeds 1 - robustnerf_smoothed_inlier_quantile.
robustnerf_smoothed_inlier_quantile: float = 0.5
# A patch counts as an inlier patch when its mean inlier fraction exceeds
# 1 - robustnerf_inner_patch_inlier_quantile.
robustnerf_inner_patch_inlier_quantile: float = 0.5
94101
# What that loss is imposed on, options are 'normals' or 'normals_pred'.
95102
orientation_loss_target: str = 'normals_pred'
96103
predicted_normal_loss_mult: float = 0.0 # Mult. on the predicted normal loss.

internal/robustnerf.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Computes RobustNeRF mask."""
2+
from typing import Mapping, Tuple
3+
4+
from jax import lax
5+
import jax.numpy as jnp
6+
7+
8+
def robustnerf_mask(
    errors: jnp.ndarray, loss_threshold: float, config
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
  """Computes RobustNeRF mask.

  Args:
    errors: f32[n,h,w,c]. Per-subpixel errors in a batch of patches. When the
      robust loss is enabled, h and w must broadcast against
      config.patch_size because the inner-patch mask has that spatial size.
    loss_threshold: f32[]. Upper bound on per-pixel loss to use to determine
      if a pixel is an inlier or not.
    config: Config object, read by attribute. Always reads
      `robustnerf_inlier_quantile` and `enable_robustnerf_loss`; when enabled
      also reads `patch_size`, `robustnerf_inner_patch_size`,
      `robustnerf_smoothed_filter_size`,
      `robustnerf_smoothed_inlier_quantile` and
      `robustnerf_inner_patch_inlier_quantile`.

  Returns:
    mask: [n,h,w,1], same dtype as `errors`. Binary mask that broadcasts to
      shape [n,h,w,c]. All ones when the robust loss is disabled.
    stats: { str: f32[] }. Statistics to pass on. Always contains
      'loss_threshold' (the quantile of the current per-pixel errors) and
      'mask'; when enabled also 'is_inlier_loss', 'has_inlier_neighbors' and
      'is_inlier_patch'.

  Raises:
    ValueError: if the robust loss is enabled and
      config.robustnerf_inner_patch_size exceeds config.patch_size.
  """
  epsilon = 1e-3
  error_dtype = errors.dtype
  error_per_pixel = jnp.mean(errors, axis=-1, keepdims=True)  # [n,h,w,1]
  # Threshold reported back to the caller through `stats`.
  next_loss_threshold = jnp.quantile(
      error_per_pixel, config.robustnerf_inlier_quantile
  )
  mask = jnp.ones_like(error_per_pixel, dtype=error_dtype)
  stats = {
      'loss_threshold': next_loss_threshold,
  }
  if config.enable_robustnerf_loss:
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if config.robustnerf_inner_patch_size > config.patch_size:
      raise ValueError(
          'robustnerf_inner_patch_size must not exceed patch_size.'
      )

    # Inlier pixels have a value of 1.0 in the mask.
    is_inlier_pixel = (error_per_pixel < loss_threshold).astype(error_dtype)
    stats['is_inlier_loss'] = jnp.mean(is_inlier_pixel)

    # Apply fxf box filter 'window' for smoothing (diffusion). The window is
    # built in the error dtype because lax.conv needs both operands to share
    # a dtype.
    f = config.robustnerf_smoothed_filter_size
    window = jnp.ones((1, 1, f, f), dtype=error_dtype) / (f * f)
    # lax.conv expects NCHW input and OIHW kernel, hence the transposes.
    has_inlier_neighbors = lax.conv(
        jnp.transpose(is_inlier_pixel, [0, 3, 1, 2]), window, (1, 1), 'SAME'
    )
    has_inlier_neighbors = jnp.transpose(has_inlier_neighbors, [0, 2, 3, 1])

    # Binarize after smoothing.
    # With the default robustnerf_smoothed_inlier_quantile of 0.5 a pixel
    # passes when more than half of its fxf window is made of inliers.
    has_inlier_neighbors = (
        has_inlier_neighbors > 1 - config.robustnerf_smoothed_inlier_quantile
    ).astype(error_dtype)
    stats['has_inlier_neighbors'] = jnp.mean(has_inlier_neighbors)
    # Union of the two criteria: the sum exceeds epsilon when either is 1.
    is_inlier_pixel = (
        has_inlier_neighbors + is_inlier_pixel > epsilon
    ).astype(error_dtype)

    # Construct binary mask for inner pixels. The entire inner patch is either
    # active or inactive. patch_size is the input patch (h,w); the inner patch
    # size can be any value not larger than patch_size.
    inner_patch_mask = _robustnerf_inner_patch_mask(
        config.robustnerf_inner_patch_size,
        config.patch_size,
        dtype=error_dtype,
    )
    is_inlier_patch = jnp.mean(
        is_inlier_pixel, axis=(1, 2), keepdims=True
    )  # [n,1,1,1]
    # robustnerf_inner_patch_inlier_quantile controls what fraction of the
    # patch must be inliers for the patch to be counted as an inlier patch.
    is_inlier_patch = (
        is_inlier_patch > 1 - config.robustnerf_inner_patch_inlier_quantile
    ).astype(error_dtype)
    is_inlier_patch = is_inlier_patch * inner_patch_mask
    stats['is_inlier_patch'] = jnp.mean(is_inlier_patch)

    # A pixel is an inlier if it is an inlier according to any of the above
    # criteria.
    mask = (
        is_inlier_patch + is_inlier_pixel > epsilon
    ).astype(error_dtype)

  stats['mask'] = jnp.mean(mask)
  return mask, stats


def _robustnerf_inner_patch_mask(
    inner_patch_size, outer_patch_size, *, dtype=jnp.float32
):
  """Constructs binary mask for inner patch.

  Args:
    inner_patch_size: Size of the (square) inside patch.
    outer_patch_size: Size of the (square) outer patch.
    dtype: dtype for result

  Returns:
    Binary mask of shape (1, outer_patch_size, outer_patch_size, 1). Mask is
    1.0 for the center (inner_patch_size, inner_patch_size) square and 0.0
    elsewhere.
  """
  # When the size difference is odd, the extra row/column of padding goes on
  # the upper (bottom/right) side.
  pad_size_lower = (outer_patch_size - inner_patch_size) // 2
  pad_size_upper = outer_patch_size - (inner_patch_size + pad_size_lower)
  mask = jnp.pad(
      jnp.ones((1, inner_patch_size, inner_patch_size, 1), dtype=dtype),
      (
          (0, 0),  # batch
          (pad_size_lower, pad_size_upper),  # height
          (pad_size_lower, pad_size_upper),  # width
          (0, 0),  # channels
      ),
  )
  return mask
116+
117+

internal/train_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from internal import math
2828
from internal import models
2929
from internal import ref_utils
30+
from internal import robustnerf
3031
from internal import stepfun
3132
from internal import utils
3233
import jax
@@ -68,7 +69,7 @@ def summarize_tree(tree, fn, ancestry=(), max_depth=3):
6869
return stats
6970

7071

71-
def compute_data_loss(batch, renderings, rays, config):
72+
def compute_data_loss(batch, renderings, rays, loss_threshold, config):
7273
"""Computes data loss terms for RGB, normal, and depth outputs."""
7374
data_losses = []
7475
stats = collections.defaultdict(lambda: [])
@@ -100,6 +101,11 @@ def compute_data_loss(batch, renderings, rays, config):
100101
scaling_grad = 1. / (1e-3 + jax.lax.stop_gradient(rgb_render_clip))
101102
# Reweighted L2 loss.
102103
data_loss = resid_sq_clip * scaling_grad**2
104+
elif config.data_loss_type == 'robustnerf':
105+
mask, robust_stats = robustnerf.robustnerf_mask(resid_sq, loss_threshold,
106+
config)
107+
data_loss = data_loss * mask
108+
stats.update(robust_stats)
103109
else:
104110
assert False
105111
data_losses.append((lossmult * data_loss).sum() / denom)
@@ -236,6 +242,7 @@ def train_step(
236242
batch,
237243
cameras,
238244
train_frac,
245+
loss_threshold,
239246
):
240247
"""One optimization step.
241248
@@ -245,6 +252,7 @@ def train_step(
245252
batch: dict, a mini-batch of data for training.
246253
cameras: module containing camera poses.
247254
train_frac: float, the fraction of training that is complete.
255+
loss_threshold: float, the loss threshold for inliers (for robustness).
248256
249257
Returns:
250258
A tuple (new_state, stats, rng) with
@@ -273,7 +281,8 @@ def loss_fn(variables):
273281

274282
losses = {}
275283

276-
data_loss, stats = compute_data_loss(batch, renderings, rays, config)
284+
data_loss, stats = compute_data_loss(batch, renderings, rays,
285+
loss_threshold, config)
277286
losses['data'] = data_loss
278287

279288
if config.interlevel_loss_mult > 0:
@@ -332,7 +341,7 @@ def loss_fn(variables):
332341
train_pstep = jax.pmap(
333342
train_step,
334343
axis_name='batch',
335-
in_axes=(0, 0, 0, None, None),
344+
in_axes=(0, 0, 0, None, None, None),
336345
donate_argnums=(0, 1))
337346
return train_pstep
338347

@@ -394,7 +403,8 @@ def setup_model(
394403
) -> Tuple[models.Model, TrainState, Callable[
395404
[FrozenVariableDict, jnp.array, utils.Rays],
396405
MutableMapping[Text, Any]], Callable[
397-
[jnp.array, TrainState, utils.Batch, Optional[Tuple[Any, ...]], float],
406+
[jnp.array, TrainState, utils.Batch,
407+
Optional[Tuple[Any, ...]], float, float],
398408
Tuple[TrainState, Dict[Text, Any], jnp.array]], Callable[[int], float]]:
399409
"""Creates NeRF model, optimizer, and pmap-ed train/render functions."""
400410

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def main(unused_argv):
106106
num_steps = config.early_exit_steps
107107
else:
108108
num_steps = config.max_steps
109+
loss_threshold = 1.0
109110
for step, batch in zip(range(init_step, num_steps + 1), pdataset):
110111

111112
if reset_stats and (jax.host_id() == 0):
@@ -122,7 +123,9 @@ def main(unused_argv):
122123
batch,
123124
cameras,
124125
train_frac,
126+
loss_threshold,
125127
)
128+
loss_threshold = jnp.mean(stats['loss_threshold'])
126129

127130
if step % config.gc_every == 0:
128131
gc.collect() # Disable automatic garbage collection for efficiency.

0 commit comments

Comments
 (0)