|
| 1 | +"""Computes RobustNeRF mask.""" |
| 2 | +from typing import Mapping, Tuple |
| 3 | + |
| 4 | +from jax import lax |
| 5 | +import jax.numpy as jnp |
| 6 | + |
| 7 | + |
| 8 | +def robustnerf_mask( |
| 9 | + errors: jnp.ndarray, loss_threshold: float, config: {str: float} |
| 10 | +) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: |
| 11 | + """Computes RobustNeRF mask. |
| 12 | +
|
| 13 | + Args: |
| 14 | + errors: f32[n,h,w,c]. Per-subpixel errors in a batch of patches. |
| 15 | + loss_threshold: f32[]. Upper bound on per-pixel loss to use to determine |
| 16 | + if a pixel is an inlier or not. |
| 17 | + config: Config object. A dictionary of hyperparameters. |
| 18 | +
|
| 19 | + Returns: |
| 20 | + mask: f32[n,h,w,c or 1]. Binary mask that broadcasts to shape [n,h,w,c]. |
| 21 | + stats: { str: f32[] }. Statistics to pass on. |
| 22 | + """ |
| 23 | + epsilon = 1e-3 |
| 24 | + error_dtype = errors.dtype |
| 25 | + error_per_pixel = jnp.mean(errors, axis=-1, keepdims=True) # f32[n,h,w,1] |
| 26 | + next_loss_threshold = jnp.quantile( |
| 27 | + error_per_pixel, config.robustnerf_inlier_quantile |
| 28 | + ) |
| 29 | + mask = jnp.ones_like(error_per_pixel, dtype=error_dtype) |
| 30 | + stats = { |
| 31 | + 'loss_threshold': next_loss_threshold, |
| 32 | + } |
| 33 | + if config.enable_robustnerf_loss: |
| 34 | + assert ( |
| 35 | + config.robustnerf_inner_patch_size <= config.patch_size |
| 36 | + ), 'patch_size must be larger than robustnerf_inner_patch_size.' |
| 37 | + |
| 38 | + # Inlier pixels have a value of 1.0 in the mask. |
| 39 | + is_inlier_pixel = (error_per_pixel < loss_threshold).astype(error_dtype) |
| 40 | + stats['is_inlier_loss'] = jnp.mean(is_inlier_pixel) |
| 41 | + |
| 42 | + # Apply fxf (3x3) box filter 'window' for smoothing (diffusion). |
| 43 | + f = config.robustnerf_smoothed_filter_size |
| 44 | + window = jnp.ones((1, 1, f, f)) / (f * f) |
| 45 | + has_inlier_neighbors = lax.conv( |
| 46 | + jnp.transpose(is_inlier_pixel, [0, 3, 1, 2]), window, (1, 1), 'SAME' |
| 47 | + ) |
| 48 | + has_inlier_neighbors = jnp.transpose(has_inlier_neighbors, [0, 2, 3, 1]) |
| 49 | + |
| 50 | + # Binarize after smoothing. |
| 51 | + # config.robustnerf_smoothed_inlier_quantile default is 0.5 which means at |
| 52 | + # least 50% of neighbouring pixels are inliers. |
| 53 | + has_inlier_neighbors = ( |
| 54 | + has_inlier_neighbors > 1 - config.robustnerf_smoothed_inlier_quantile |
| 55 | + ).astype(error_dtype) |
| 56 | + stats['has_inlier_neighbors'] = jnp.mean(has_inlier_neighbors) |
| 57 | + is_inlier_pixel = ( |
| 58 | + has_inlier_neighbors + is_inlier_pixel > epsilon |
| 59 | + ).astype(error_dtype) |
| 60 | + # Construct binary mask for inner pixels. The entire inner patch is either |
| 61 | + # active or inactive. |
| 62 | + # patch_size is the input patch (h,w), inner patch size can be any value |
| 63 | + # smaller than patch_size. Default is for the inner patch size to be half |
| 64 | + # the input patch size (i.e. 16x16 -> 8x8). |
| 65 | + inner_patch_mask = _robustnerf_inner_patch_mask( |
| 66 | + config.robustnerf_inner_patch_size, config.patch_size |
| 67 | + ) |
| 68 | + is_inlier_patch = jnp.mean( |
| 69 | + is_inlier_pixel, axis=[1, 2], keepdims=True |
| 70 | + ) # f32[n,1,1,1] |
| 71 | + # robustnerf_inner_patch_inlier_quantile what percentage of the patch |
| 72 | + # should be inliers so that the patch is counted as an inlier patch. |
| 73 | + is_inlier_patch = ( |
| 74 | + is_inlier_patch > 1 - config.robustnerf_inner_patch_inlier_quantile |
| 75 | + ).astype(error_dtype) |
| 76 | + is_inlier_patch = is_inlier_patch * inner_patch_mask |
| 77 | + stats['is_inlier_patch'] = jnp.mean(is_inlier_patch) |
| 78 | + |
| 79 | + # A pixel is an inlier if it is an inlier according to any of the above |
| 80 | + # criteria. |
| 81 | + mask = ( |
| 82 | + is_inlier_patch + is_inlier_pixel > epsilon |
| 83 | + ).astype(error_dtype) |
| 84 | + |
| 85 | + stats['mask'] = jnp.mean(mask) |
| 86 | + return mask, stats |
| 87 | + |
| 88 | + |
| 89 | +def _robustnerf_inner_patch_mask( |
| 90 | + inner_patch_size, outer_patch_size, *, dtype=jnp.float32 |
| 91 | +): |
| 92 | + """Constructs binary mask for inner patch. |
| 93 | +
|
| 94 | + Args: |
| 95 | + inner_patch_size: Size of the (square) inside patch. |
| 96 | + outer_patch_size: Size of the (square) outer patch. |
| 97 | + dtype: dtype for result |
| 98 | +
|
| 99 | + Returns: |
| 100 | + Binary mask of shape (1, outer_patch_size, outer_patch_size, 1). Mask is |
| 101 | + 1.0 for the center (inner_patch_size, inner_patch_size) square and 0.0 |
| 102 | + elsewhere. |
| 103 | + """ |
| 104 | + pad_size_lower = (outer_patch_size - inner_patch_size) // 2 |
| 105 | + pad_size_upper = outer_patch_size - (inner_patch_size + pad_size_lower) |
| 106 | + mask = jnp.pad( |
| 107 | + jnp.ones((1, inner_patch_size, inner_patch_size, 1), dtype=dtype), |
| 108 | + ( |
| 109 | + (0, 0), # batch |
| 110 | + (pad_size_lower, pad_size_upper), # height |
| 111 | + (pad_size_lower, pad_size_upper), # width |
| 112 | + (0, 0), # channels |
| 113 | + ), |
| 114 | + ) |
| 115 | + return mask |
| 116 | + |
| 117 | + |
0 commit comments