This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 99125d6

Add support for multi-camera datasets
1 parent 5b4d4f6 commit 99125d6

File tree

internal/camera_utils.py
internal/datasets.py

2 files changed: 164 additions & 77 deletions

internal/camera_utils.py

Lines changed: 56 additions & 21 deletions
@@ -573,25 +573,51 @@ def pix_to_dir(x, y):
   # Apply inverse intrinsic matrices.
   camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked)
 
-  if distortion_params is not None:
+  mask = camtype > 0
+  if xnp.any(mask):
+    is_uniform = xnp.all(mask)
+    if is_uniform:
+      ldistortion_params = distortion_params
+      dl = camera_dirs_stacked
+    else:
+      ldistortion_params = distortion_params[mask, :]
+      dl = camera_dirs_stacked[:, mask, :]
+
     # Correct for distortion.
+    dist_dict = dict(zip(
+        ["k1", "k2", "k3", "k4", "p1", "p2"],
+        xnp.moveaxis(ldistortion_params, -1, 0)))
     x, y = _radial_and_tangential_undistort(
-        camera_dirs_stacked[..., 0],
-        camera_dirs_stacked[..., 1],
-        **distortion_params,
-        xnp=xnp)
-    camera_dirs_stacked = xnp.stack([x, y, xnp.ones_like(x)], -1)
-
-  if camtype == ProjectionType.FISHEYE:
-    theta = xnp.sqrt(xnp.sum(xnp.square(camera_dirs_stacked[..., :2]), axis=-1))
-    theta = xnp.minimum(xnp.pi, theta)
-
-    sin_theta_over_theta = xnp.sin(theta) / theta
-    camera_dirs_stacked = xnp.stack([
-        camera_dirs_stacked[..., 0] * sin_theta_over_theta,
-        camera_dirs_stacked[..., 1] * sin_theta_over_theta,
-        xnp.cos(theta),
-    ], axis=-1)
+        dl[..., 0],
+        dl[..., 1],
+        **dist_dict,
+        xnp=xnp)
+    dl = xnp.stack([x, y, xnp.ones_like(x)], -1)
+    dcamera_types = camtype[mask]
+
+    fisheye_mask = dcamera_types == 2
+    if fisheye_mask.any():
+      is_all_fisheye = xnp.all(fisheye_mask)
+      if is_all_fisheye:
+        dll = dl
+      else:
+        dll = dl[:, mask, :2]
+      theta = xnp.sqrt(xnp.sum(xnp.square(dll[..., :2]), axis=-1))
+      theta = xnp.minimum(xnp.pi, theta)
+      sin_theta_over_theta = xnp.sin(theta) / theta
+
+      if is_all_fisheye:
+        dl[..., :2] *= sin_theta_over_theta
+        dl[..., 2:] *= xnp.cos(theta)
+      else:
+        dl[:, mask, :2] *= sin_theta_over_theta
+        dl[:, mask, 2:] *= xnp.cos(theta)
+
+  if mask.any():
+    if is_uniform:
+      camera_dirs_stacked = dl
+    else:
+      camera_dirs_stacked[:, mask, :] = dl
 
   # Flip from OpenCV to OpenGL coordinate system.
   camera_dirs_stacked = matmul(camera_dirs_stacked,
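
Note: the hunk above replaces the single shared camtype/distortion path with per-camera dispatch via boolean masks. A minimal standalone sketch of that masking pattern in plain NumPy (illustrative only, not code from this commit; the 0/1/2 type codes follow the convention used in the diff):

import numpy as np

def apply_fisheye(dirs):
  # Equidistant fisheye mapping, mirroring the math in the hunk above.
  theta = np.sqrt(np.sum(np.square(dirs[..., :2]), axis=-1))
  theta = np.minimum(np.pi, theta)
  sin_theta_over_theta = np.sin(theta) / theta
  out = dirs.copy()
  out[..., :2] *= sin_theta_over_theta[..., None]
  out[..., 2:] *= np.cos(theta)[..., None]
  return out

# Toy batch: [num_rays, num_cams, 3] directions, one type code per camera.
camera_dirs = np.random.randn(4, 3, 3)
camtype = np.array([0, 1, 2])  # 0 = undistorted, 1 = perspective, 2 = fisheye

mask = camtype > 0             # cameras that need any distortion handling
if mask.any():
  dl = camera_dirs[:, mask, :]           # boolean indexing copies the subset
  # (radial/tangential undistortion of dl[..., 0], dl[..., 1] would go here)
  fisheye = (camtype[mask] == 2)
  dl[:, fisheye, :] = apply_fisheye(dl[:, fisheye, :])
  camera_dirs[:, mask, :] = dl           # scatter the corrected subset back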
@@ -655,21 +681,30 @@ def cast_ray_batch(
   Returns:
     rays: Rays dataclass with computed 3D world space ray data.
   """
-  pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras
+  del camtype
+  pixtocams, camtoworlds, distortion_params, pixtocam_ndc, camtype = cameras
 
   # pixels.cam_idx has shape [..., 1], remove this hanging dimension.
   cam_idx = pixels.cam_idx[..., 0]
   batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx]
 
+  bs = pixels.pix_x_int.shape
+  dtype = pixtocams.dtype
+  origins = xnp.zeros((*bs, 3), dtype=dtype)
+  directions = xnp.zeros((*bs, 3), dtype=dtype)
+  viewdirs = xnp.zeros((*bs, 3), dtype=dtype)
+  radii = xnp.zeros((*bs, 1), dtype=dtype)
+  imageplane = xnp.zeros((*bs, 1), dtype=dtype)
+
   # Compute rays from pixel coordinates.
   origins, directions, viewdirs, radii, imageplane = pixels_to_rays(
       pixels.pix_x_int,
       pixels.pix_y_int,
       batch_index(pixtocams),
       batch_index(camtoworlds),
-      distortion_params=distortion_params,
-      pixtocam_ndc=pixtocam_ndc,
-      camtype=camtype,
+      distortion_params=distortion_params[cam_idx],
+      pixtocam_ndc=pixtocam_ndc[cam_idx] if pixtocam_ndc is not None else None,
+      camtype=camtype[cam_idx],
       xnp=xnp)
 
   # Create Rays data structure.
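
Note: the cast_ray_batch change relies on stacking every per-camera quantity along a leading axis so it can be gathered by cam_idx. A small sketch of that gather with toy shapes (assumed values, not from the commit):

import numpy as np

num_cams, num_rays = 3, 5
pixtocams = np.stack([np.eye(3)] * num_cams)            # [num_cams, 3, 3]
distortion_params = np.zeros((num_cams, 6))             # [num_cams, 6]
camtype = np.array([1, 1, 2])                           # [num_cams]

cam_idx = np.random.randint(0, num_cams, (num_rays,))   # one camera per ray

# Indexing by cam_idx turns per-camera arrays into per-ray arrays.
assert pixtocams[cam_idx].shape == (num_rays, 3, 3)
assert distortion_params[cam_idx].shape == (num_rays, 6)
assert camtype[cam_idx].shape == (num_rays,)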

internal/datasets.py

Lines changed: 108 additions & 56 deletions
@@ -79,22 +79,69 @@ def process(
     # self.load_points3D()  # For now, we do not need the point cloud data.
 
     # Assume shared intrinsics between all cameras.
-    cam = self.cameras[1]
+    cams = {}
+    for cam_id, cam in self.cameras.items():
+      # Extract focal lengths and principal point parameters.
+      fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
+      pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))
+
+      # Get distortion parameters.
+      type_ = cam.camera_type
+
+      if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
+        params = None
+        camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+      elif type_ == 1 or type_ == 'PINHOLE':
+        params = None
+        camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+      if type_ == 2 or type_ == 'SIMPLE_RADIAL':
+        params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
+        params['k1'] = cam.k1
+        camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+      elif type_ == 3 or type_ == 'RADIAL':
+        params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
+        params['k1'] = cam.k1
+        params['k2'] = cam.k2
+        camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+      elif type_ == 4 or type_ == 'OPENCV':
+        params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
+        params['k1'] = cam.k1
+        params['k2'] = cam.k2
+        params['p1'] = cam.p1
+        params['p2'] = cam.p2
+        camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+      elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
+        params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
+        params['k1'] = cam.k1
+        params['k2'] = cam.k2
+        params['k3'] = cam.k3
+        params['k4'] = cam.k4
+        camtype = camera_utils.ProjectionType.FISHEYE
+      cams[cam_id] = (cam, pixtocam, params, camtype)
 
-    # Extract focal lengths and principal point parameters.
-    fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
-    pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))
 
     # Extract extrinsic matrices in world-to-camera format.
     imdata = self.images
     w2c_mats = []
+    pixtocams = []
+    all_params = []
+    all_camtypes = []
     bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
     for k in imdata:
       im = imdata[k]
       rot = im.R()
       trans = im.tvec.reshape(3, 1)
       w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
       w2c_mats.append(w2c)
+      cam, pixtocam, params, camtype = cams[im.camera_id]
+      all_params.append(params)
+      all_camtypes.append(camtype)
+      pixtocams.append(pixtocam)
     w2c_mats = np.stack(w2c_mats, axis=0)
 
     # Convert extrinsics to camera-to-world.
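
Note: the per-camera branch above implements the usual COLMAP camera-model table. A compressed restatement as a hypothetical helper (colmap_camera_to_distortion is not part of the commit; cam is assumed to be a pycolmap camera record exposing k1/k2/k3/k4/p1/p2 fields):

def colmap_camera_to_distortion(cam_type, cam):
  """Maps a COLMAP camera model to (distortion params or None, is_fisheye)."""
  if cam_type in (0, 'SIMPLE_PINHOLE', 1, 'PINHOLE'):
    return None, False
  if cam_type in (2, 'SIMPLE_RADIAL'):
    return {'k1': cam.k1, 'k2': 0., 'k3': 0., 'p1': 0., 'p2': 0.}, False
  if cam_type in (3, 'RADIAL'):
    return {'k1': cam.k1, 'k2': cam.k2, 'k3': 0., 'p1': 0., 'p2': 0.}, False
  if cam_type in (4, 'OPENCV'):
    return {'k1': cam.k1, 'k2': cam.k2, 'k3': 0.,
            'p1': cam.p1, 'p2': cam.p2}, False
  if cam_type in (5, 'OPENCV_FISHEYE'):
    return {'k1': cam.k1, 'k2': cam.k2, 'k3': cam.k3, 'k4': cam.k4}, True
  raise ValueError(f'Unsupported COLMAP camera type: {cam_type}')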
@@ -108,45 +155,9 @@ def process(
     # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.
     poses = poses @ np.diag([1, -1, -1, 1])
 
-    # Get distortion parameters.
-    type_ = cam.camera_type
-
-    if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
-      params = None
-      camtype = camera_utils.ProjectionType.PERSPECTIVE
-
-    elif type_ == 1 or type_ == 'PINHOLE':
-      params = None
-      camtype = camera_utils.ProjectionType.PERSPECTIVE
-
-    if type_ == 2 or type_ == 'SIMPLE_RADIAL':
-      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
-      params['k1'] = cam.k1
-      camtype = camera_utils.ProjectionType.PERSPECTIVE
-
-    elif type_ == 3 or type_ == 'RADIAL':
-      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
-      params['k1'] = cam.k1
-      params['k2'] = cam.k2
-      camtype = camera_utils.ProjectionType.PERSPECTIVE
-
-    elif type_ == 4 or type_ == 'OPENCV':
-      params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
-      params['k1'] = cam.k1
-      params['k2'] = cam.k2
-      params['p1'] = cam.p1
-      params['p2'] = cam.p2
-      camtype = camera_utils.ProjectionType.PERSPECTIVE
-
-    elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
-      params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
-      params['k1'] = cam.k1
-      params['k2'] = cam.k2
-      params['k3'] = cam.k3
-      params['k4'] = cam.k4
-      camtype = camera_utils.ProjectionType.FISHEYE
+    pixtocams = np.stack(pixtocams)
 
-    return names, poses, pixtocam, params, camtype
+    return names, poses, pixtocams, all_params, all_camtypes
 
 
 def load_blender_posedata(data_dir, split=None):
@@ -288,8 +299,9 @@ def __init__(self,
     self.images: np.ndarray = None
     self.camtoworlds: np.ndarray = None
     self.pixtocams: np.ndarray = None
-    self.height: int = None
-    self.width: int = None
+    self.height: np.ndarray = None
+    self.width: np.ndarray = None
+
 
     # Load data from disk using provided config parameters.
     self._load_renderings(config)
@@ -314,11 +326,41 @@ def __init__(self,
                                                  self.height)
 
     self._n_examples = self.camtoworlds.shape[0]
+    if len(self.pixtocams.shape) == 2:
+      self.pixtocams = np.repeat(self.pixtocams[None], self._n_examples, 0)
+    if not isinstance(self.focal, np.ndarray):
+      self.focal = np.full((self._n_examples,), self.focal, dtype=np.int32)
+    if not isinstance(self.width, np.ndarray):
+      self.width = np.full((self._n_examples,), self.width, dtype=np.int32)
+    if not isinstance(self.height, np.ndarray):
+      self.height = np.full((self._n_examples,), self.height, dtype=np.int32)
+    if not isinstance(self.camtype, list):
+      self.camtype = [self.camtype] * self._n_examples
+    map_camera = {
+        camera_utils.ProjectionType.PERSPECTIVE.value: 1,
+        camera_utils.ProjectionType.FISHEYE.value: 2
+    }
+    self.camtype = np.array([map_camera[x.value] for x in self.camtype], dtype=np.int32)
+    distortion_params = np.zeros((self._n_examples, 6), dtype=np.float32)
+    for i in range(self._n_examples):
+      k = self.distortion_params
+      if isinstance(self.distortion_params, list):
+        try:
+          k = self.distortion_params[i]
+        except Exception as e:
+          breakpoint()
+          print(e)
+      if k is None:
+        self.camtype[i] = 0
+        continue
+      distortion_params[i] = np.array([k[m] for m in ["k1", "k2", "k3", "k4", "p1", "p2"]], dtype=np.float32)
+    self.distortion_params = distortion_params
 
     self.cameras = (self.pixtocams,
                     self.camtoworlds,
                     self.distortion_params,
-                    self.pixtocam_ndc)
+                    self.pixtocam_ndc,
+                    self.camtype)
 
     # Seed the queue with one batch to avoid race condition.
     if self.split == utils.DataSplit.TRAIN:
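
Note: the block above flattens heterogeneous per-image distortion dictionaries into one fixed-width array plus an integer camera-type code (0 = none, 1 = perspective, 2 = fisheye). A standalone sketch of the same packing idea (toy values; the sketch assumes every dictionary carries all six keys):

import numpy as np

per_image_params = [
    {'k1': 0.1, 'k2': 0.01, 'k3': 0., 'k4': 0., 'p1': 0., 'p2': 0.},  # distorted
    None,                                                             # pinhole
]
camtype = np.array([1, 1], dtype=np.int32)

packed = np.zeros((len(per_image_params), 6), dtype=np.float32)
for i, k in enumerate(per_image_params):
  if k is None:
    camtype[i] = 0  # no distortion parameters: mark the camera as undistorted
    continue
  packed[i] = [k[m] for m in ['k1', 'k2', 'k3', 'k4', 'p1', 'p2']]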
@@ -456,23 +498,25 @@ def _next_train(self) -> utils.Batch:
     num_patches = self._batch_size // self._patch_size ** 2
     lower_border = self._num_border_pixels_to_mask
     upper_border = self._num_border_pixels_to_mask + self._patch_size - 1
+
+    # Random camera indices.
+    if self._batching == utils.BatchingMethod.ALL_IMAGES:
+      cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
+    else:
+      cam_idx = np.random.randint(0, self._n_examples, (1,))
+
     # Random pixel patch x-coordinates.
-    pix_x_int = np.random.randint(lower_border, self.width - upper_border,
+    pix_x_int = np.random.randint(lower_border, self.width[cam_idx] - upper_border,
                                   (num_patches, 1, 1))
     # Random pixel patch y-coordinates.
-    pix_y_int = np.random.randint(lower_border, self.height - upper_border,
+    pix_y_int = np.random.randint(lower_border, self.height[cam_idx] - upper_border,
                                   (num_patches, 1, 1))
     # Add patch coordinate offsets.
     # Shape will broadcast to (num_patches, _patch_size, _patch_size).
     patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates(
         self._patch_size, self._patch_size)
     pix_x_int = pix_x_int + patch_dx_int
     pix_y_int = pix_y_int + patch_dy_int
-    # Random camera indices.
-    if self._batching == utils.BatchingMethod.ALL_IMAGES:
-      cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
-    else:
-      cam_idx = np.random.randint(0, self._n_examples, (1,))
 
     if self._apply_bayer_mask:
       # Compute the Bayer mosaic mask for each pixel in the batch.
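
Note: moving the camera-index sampling ahead of the pixel sampling lets the patch bounds depend on each sampled camera's own resolution; np.random.randint broadcasts array-valued bounds, which the diff relies on. A toy sketch with made-up resolutions (not from the commit):

import numpy as np

widths = np.array([640, 1280, 800])
heights = np.array([480, 720, 600])
num_patches, patch_size, border = 4, 8, 0

cam_idx = np.random.randint(0, len(widths), (num_patches, 1, 1))
upper = border + patch_size - 1
pix_x = np.random.randint(border, widths[cam_idx] - upper, (num_patches, 1, 1))
pix_y = np.random.randint(border, heights[cam_idx] - upper, (num_patches, 1, 1))
assert pix_x.shape == pix_y.shape == (num_patches, 1, 1)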
@@ -488,12 +532,12 @@ def generate_ray_batch(self, cam_idx: int) -> utils.Batch:
     if self._render_spherical:
       camtoworld = self.camtoworlds[cam_idx]
       rays = camera_utils.cast_spherical_rays(
-          camtoworld, self.height, self.width, self.near, self.far, xnp=np)
+          camtoworld, self.height[cam_idx], self.width[cam_idx], self.near, self.far, xnp=np)
       return utils.Batch(rays=rays)
     else:
       # Generate rays for all pixels in the image.
       pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
-          self.width, self.height)
+          self.width[cam_idx], self.height[cam_idx])
       return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx)
 
   def _next_test(self) -> utils.Batch:
@@ -593,13 +637,15 @@ def _load_renderings(self, config):
     inds = np.argsort(image_names)
     image_names = [image_names[i] for i in inds]
     poses = poses[inds]
+    pixtocam = pixtocam[inds]
+    distortion_params = [distortion_params[i] for i in inds]
+    camtype = [camtype[i] for i in inds]
 
     # Scale the inverse intrinsics matrix by the image downsampling factor.
     pixtocam = pixtocam @ np.diag([factor, factor, 1.])
     self.pixtocams = pixtocam.astype(np.float32)
-    self.focal = 1. / self.pixtocams[0, 0]
+    self.focal = 1. / self.pixtocams[..., 0, 0]
     self.distortion_params = distortion_params
-    self.camtype = camtype
 
     raw_testscene = False
     if config.rawnerf_mode:
@@ -706,6 +752,12 @@ def _load_renderings(self, config):
     # All per-image quantities must be re-indexed using the split indices.
     images = images[indices]
     poses = poses[indices]
+    self.pixtocams = self.pixtocams[indices]
+    self.focal = self.focal[indices]
+    list_indices = np.where(indices)[0] if indices.dtype == np.bool_ else indices
+    self.distortion_params = [self.distortion_params[i] for i in list_indices]
+    self.camtype = [camtype[i] for i in list_indices]
+    assert len(self.camtype) == len(self.pixtocams)
     if self.exposures is not None:
       self.exposures = self.exposures[indices]
     if config.rawnerf_mode:
