Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Add support for multi-camera datasets #145

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions internal/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,25 +573,51 @@ def pix_to_dir(x, y):
# Apply inverse intrinsic matrices.
camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked)

if distortion_params is not None:
mask = camtype > 0
if xnp.any(mask):
is_uniform = xnp.all(mask)
if is_uniform:
ldistortion_params = distortion_params
dl = camera_dirs_stacked
else:
ldistortion_params = distortion_params[mask, :]
dl = camera_dirs_stacked[:, mask, :]

# Correct for distortion.
dist_dict = dict(zip(
["k1", "k2", "k3", "k4", "p1", "p2"],
xnp.moveaxis(ldistortion_params, -1, 0)))
x, y = _radial_and_tangential_undistort(
camera_dirs_stacked[..., 0],
camera_dirs_stacked[..., 1],
**distortion_params,
xnp=xnp)
camera_dirs_stacked = xnp.stack([x, y, xnp.ones_like(x)], -1)

if camtype == ProjectionType.FISHEYE:
theta = xnp.sqrt(xnp.sum(xnp.square(camera_dirs_stacked[..., :2]), axis=-1))
theta = xnp.minimum(xnp.pi, theta)

sin_theta_over_theta = xnp.sin(theta) / theta
camera_dirs_stacked = xnp.stack([
camera_dirs_stacked[..., 0] * sin_theta_over_theta,
camera_dirs_stacked[..., 1] * sin_theta_over_theta,
xnp.cos(theta),
], axis=-1)
dl[..., 0],
dl[..., 1],
**dist_dict,
xnp=xnp)
dl = xnp.stack([x, y, xnp.ones_like(x)], -1)
dcamera_types = camtype[mask]

fisheye_mask = dcamera_types == 2
if fisheye_mask.any():
is_all_fisheye = xnp.all(fisheye_mask)
if is_all_fisheye:
dll = dl
else:
dll = dl[:, mask, :2]
theta = xnp.sqrt(xnp.sum(xnp.square(dll[..., :2]), axis=-1, keepdims=True))
theta = xnp.minimum(xnp.pi, theta)
sin_theta_over_theta = xnp.sin(theta) / theta

if is_all_fisheye:
dl[..., :2] *= sin_theta_over_theta
dl[..., 2:] *= xnp.cos(theta)
else:
dl[:, mask, :2] *= sin_theta_over_theta
dl[:, mask, 2:] *= xnp.cos(theta)

if mask.any():
if is_uniform:
camera_dirs_stacked = dl
else:
camera_dirs_stacked[:, mask, :] = dl

# Flip from OpenCV to OpenGL coordinate system.
camera_dirs_stacked = matmul(camera_dirs_stacked,
Expand Down Expand Up @@ -655,21 +681,30 @@ def cast_ray_batch(
Returns:
rays: Rays dataclass with computed 3D world space ray data.
"""
pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras
del camtype
pixtocams, camtoworlds, distortion_params, pixtocam_ndc, camtype = cameras

# pixels.cam_idx has shape [..., 1], remove this hanging dimension.
cam_idx = pixels.cam_idx[..., 0]
batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx]

bs = pixels.pix_x_int.shape
dtype = pixtocams.dtype
origins = xnp.zeros((*bs, 3), dtype=dtype)
directions = xnp.zeros((*bs, 3), dtype=dtype)
viewdirs = xnp.zeros((*bs, 3), dtype=dtype)
radii = xnp.zeros((*bs, 1), dtype=dtype)
imageplane = xnp.zeros((*bs, 1), dtype=dtype)

# Compute rays from pixel coordinates.
origins, directions, viewdirs, radii, imageplane = pixels_to_rays(
pixels.pix_x_int,
pixels.pix_y_int,
batch_index(pixtocams),
batch_index(camtoworlds),
distortion_params=distortion_params,
pixtocam_ndc=pixtocam_ndc,
camtype=camtype,
distortion_params=distortion_params[cam_idx],
pixtocam_ndc=pixtocam_ndc[cam_idx] if pixtocam_ndc is not None else None,
camtype=camtype[cam_idx],
xnp=xnp)

# Create Rays data structure.
Expand Down
177 changes: 119 additions & 58 deletions internal/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,69 @@ def process(
# self.load_points3D() # For now, we do not need the point cloud data.

# Collect the inverse intrinsics, distortion parameters and projection type of each camera, keyed by camera id.
cam = self.cameras[1]
cams = {}
for cam_id, cam in self.cameras.items():
# Extract focal lengths and principal point parameters.
fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))

# Get distortion parameters.
type_ = cam.camera_type

if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
params = None
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 1 or type_ == 'PINHOLE':
params = None
camtype = camera_utils.ProjectionType.PERSPECTIVE

if type_ == 2 or type_ == 'SIMPLE_RADIAL':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
params['k1'] = cam.k1
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 3 or type_ == 'RADIAL':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
params['k1'] = cam.k1
params['k2'] = cam.k2
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 4 or type_ == 'OPENCV':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
params['k1'] = cam.k1
params['k2'] = cam.k2
params['p1'] = cam.p1
params['p2'] = cam.p2
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
params['k1'] = cam.k1
params['k2'] = cam.k2
params['k3'] = cam.k3
params['k4'] = cam.k4
camtype = camera_utils.ProjectionType.FISHEYE
cams[cam_id] = (cam, pixtocam, params, camtype)

# Extract focal lengths and principal point parameters.
fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))

# Extract extrinsic matrices in world-to-camera format.
imdata = self.images
w2c_mats = []
pixtocams = []
all_params = []
all_camtypes = []
bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
for k in imdata:
im = imdata[k]
rot = im.R()
trans = im.tvec.reshape(3, 1)
w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
w2c_mats.append(w2c)
cam, pixtocam, params, camtype = cams[im.camera_id]
all_params.append(params)
all_camtypes.append(camtype)
pixtocams.append(pixtocam)
w2c_mats = np.stack(w2c_mats, axis=0)

# Convert extrinsics to camera-to-world.
Expand All @@ -108,45 +155,9 @@ def process(
# Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.
poses = poses @ np.diag([1, -1, -1, 1])

# Get distortion parameters.
type_ = cam.camera_type

if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
params = None
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 1 or type_ == 'PINHOLE':
params = None
camtype = camera_utils.ProjectionType.PERSPECTIVE

if type_ == 2 or type_ == 'SIMPLE_RADIAL':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
params['k1'] = cam.k1
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 3 or type_ == 'RADIAL':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
params['k1'] = cam.k1
params['k2'] = cam.k2
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 4 or type_ == 'OPENCV':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
params['k1'] = cam.k1
params['k2'] = cam.k2
params['p1'] = cam.p1
params['p2'] = cam.p2
camtype = camera_utils.ProjectionType.PERSPECTIVE

elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
params['k1'] = cam.k1
params['k2'] = cam.k2
params['k3'] = cam.k3
params['k4'] = cam.k4
camtype = camera_utils.ProjectionType.FISHEYE
pixtocams = np.stack(pixtocams)

return names, poses, pixtocam, params, camtype
return names, poses, pixtocams, all_params, all_camtypes


def load_blender_posedata(data_dir, split=None):
Expand Down Expand Up @@ -288,8 +299,9 @@ def __init__(self,
self.images: np.ndarray = None
self.camtoworlds: np.ndarray = None
self.pixtocams: np.ndarray = None
self.height: int = None
self.width: int = None
self.height: np.ndarray = None
self.width: np.ndarray = None


# Load data from disk using provided config parameters.
self._load_renderings(config)
Expand All @@ -314,11 +326,45 @@ def __init__(self,
self.height)

self._n_examples = self.camtoworlds.shape[0]
if len(self.pixtocams.shape) == 2:
self.pixtocams = np.repeat(self.pixtocams[None], self._n_examples, 0)
if not isinstance(self.focal, np.ndarray):
self.focal = np.full((self._n_examples,), self.focal, dtype=np.float32)
if not isinstance(self.near, np.ndarray):
self.near = np.full((self._n_examples,), self.near, dtype=np.float32)
if not isinstance(self.height, np.ndarray):
self.far = np.full((self._n_examples,), self.far, dtype=np.float32)
if not isinstance(self.width, np.ndarray):
self.width = np.full((self._n_examples,), self.width, dtype=np.int32)
if not isinstance(self.height, np.ndarray):
self.height = np.full((self._n_examples,), self.height, dtype=np.int32)
if not isinstance(self.camtype, list):
self.camtype = [self.camtype] * self._n_examples
map_camera = {
camera_utils.ProjectionType.PERSPECTIVE.value: 1,
camera_utils.ProjectionType.FISHEYE.value: 2
}
self.camtype = np.array([map_camera[x.value] for x in self.camtype], dtype=np.int32)
distortion_params = np.zeros((self._n_examples, 6), dtype=np.float32)
for i in range(self._n_examples):
k = self.distortion_params
if isinstance(self.distortion_params, list):
try:
k = self.distortion_params[i]
except Exception as e:
breakpoint()
print(e)
if k is None:
self.camtype[i] = 0
continue
distortion_params[i] = np.array([k[m] for m in ["k1", "k2", "k3", "k4", "p1", "p2"]], dtype=np.float32)
self.distortion_params = distortion_params

self.cameras = (self.pixtocams,
self.camtoworlds,
self.distortion_params,
self.pixtocam_ndc)
self.pixtocam_ndc,
self.camtype)

# Seed the queue with one batch to avoid race condition.
if self.split == utils.DataSplit.TRAIN:
Expand Down Expand Up @@ -408,10 +454,11 @@ def _make_ray_batch(self,
"""

broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
idx = 0 if self.render_path else cam_idx
ray_kwargs = {
'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult,
'near': broadcast_scalar(self.near),
'far': broadcast_scalar(self.far),
'near': broadcast_scalar(self.near[idx]),
'far': broadcast_scalar(self.far[idx]),
'cam_idx': broadcast_scalar(cam_idx),
}
# Collect per-camera information needed for each ray.
Expand Down Expand Up @@ -456,23 +503,25 @@ def _next_train(self) -> utils.Batch:
num_patches = self._batch_size // self._patch_size ** 2
lower_border = self._num_border_pixels_to_mask
upper_border = self._num_border_pixels_to_mask + self._patch_size - 1

# Random camera indices.
if self._batching == utils.BatchingMethod.ALL_IMAGES:
cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
else:
cam_idx = np.random.randint(0, self._n_examples, (1,))

# Random pixel patch x-coordinates.
pix_x_int = np.random.randint(lower_border, self.width - upper_border,
pix_x_int = np.random.randint(lower_border, self.width[cam_idx] - upper_border,
(num_patches, 1, 1))
# Random pixel patch y-coordinates.
pix_y_int = np.random.randint(lower_border, self.height - upper_border,
pix_y_int = np.random.randint(lower_border, self.height[cam_idx] - upper_border,
(num_patches, 1, 1))
# Add patch coordinate offsets.
# Shape will broadcast to (num_patches, _patch_size, _patch_size).
patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates(
self._patch_size, self._patch_size)
pix_x_int = pix_x_int + patch_dx_int
pix_y_int = pix_y_int + patch_dy_int
# Random camera indices.
if self._batching == utils.BatchingMethod.ALL_IMAGES:
cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
else:
cam_idx = np.random.randint(0, self._n_examples, (1,))

if self._apply_bayer_mask:
# Compute the Bayer mosaic mask for each pixel in the batch.
Expand All @@ -488,12 +537,16 @@ def generate_ray_batch(self, cam_idx: int) -> utils.Batch:
if self._render_spherical:
camtoworld = self.camtoworlds[cam_idx]
rays = camera_utils.cast_spherical_rays(
camtoworld, self.height, self.width, self.near, self.far, xnp=np)
camtoworld,
self.height[cam_idx], self.width[cam_idx],
self.near[cam_idx],
self.far[cam_idx],
xnp=np)
return utils.Batch(rays=rays)
else:
# Generate rays for all pixels in the image.
pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
self.width, self.height)
self.width[cam_idx], self.height[cam_idx])
return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx)

def _next_test(self) -> utils.Batch:
Expand Down Expand Up @@ -593,13 +646,15 @@ def _load_renderings(self, config):
inds = np.argsort(image_names)
image_names = [image_names[i] for i in inds]
poses = poses[inds]
pixtocam = pixtocam[inds]
distortion_params = [distortion_params[i] for i in inds]
camtype = [camtype[i] for i in inds]

# Scale the inverse intrinsics matrix by the image downsampling factor.
pixtocam = pixtocam @ np.diag([factor, factor, 1.])
self.pixtocams = pixtocam.astype(np.float32)
self.focal = 1. / self.pixtocams[0, 0]
self.focal = 1. / self.pixtocams[..., 0, 0]
self.distortion_params = distortion_params
self.camtype = camtype

raw_testscene = False
if config.rawnerf_mode:
Expand Down Expand Up @@ -706,6 +761,12 @@ def _load_renderings(self, config):
# All per-image quantities must be re-indexed using the split indices.
images = images[indices]
poses = poses[indices]
self.pixtocams = self.pixtocams[indices]
self.focal = self.focal[indices]
list_indices = np.where(indices)[0] if indices.dtype == np.bool_ else indices
self.distortion_params = [self.distortion_params[i] for i in list_indices]
self.camtype = [camtype[i] for i in list_indices]
assert len(self.camtype) == len(self.pixtocams)
if self.exposures is not None:
self.exposures = self.exposures[indices]
if config.rawnerf_mode:
Expand Down