Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 06b0195

Browse files
committed
Fix focal lenght dtype + add support for different znear,zfar
1 parent 99125d6 commit 06b0195

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

internal/datasets.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,11 @@ def __init__(self,
329329
if len(self.pixtocams.shape) == 2:
330330
self.pixtocams = np.repeat(self.pixtocams[None], self._n_examples, 0)
331331
if not isinstance(self.focal, np.ndarray):
332-
self.focal = np.full((self._n_examples,), self.focal, dtype=np.int32)
332+
self.focal = np.full((self._n_examples,), self.focal, dtype=np.float32)
333+
if not isinstance(self.near, np.ndarray):
334+
self.near = np.full((self._n_examples,), self.near, dtype=np.float32)
335+
if not isinstance(self.height, np.ndarray):
336+
self.far = np.full((self._n_examples,), self.far, dtype=np.float32)
333337
if not isinstance(self.width, np.ndarray):
334338
self.width = np.full((self._n_examples,), self.width, dtype=np.int32)
335339
if not isinstance(self.height, np.ndarray):
@@ -450,10 +454,11 @@ def _make_ray_batch(self,
450454
"""
451455

452456
broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
457+
idx = 0 if self.render_path else cam_idx
453458
ray_kwargs = {
454459
'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult,
455-
'near': broadcast_scalar(self.near),
456-
'far': broadcast_scalar(self.far),
460+
'near': broadcast_scalar(self.near[idx]),
461+
'far': broadcast_scalar(self.far[idx]),
457462
'cam_idx': broadcast_scalar(cam_idx),
458463
}
459464
# Collect per-camera information needed for each ray.
@@ -532,7 +537,11 @@ def generate_ray_batch(self, cam_idx: int) -> utils.Batch:
532537
if self._render_spherical:
533538
camtoworld = self.camtoworlds[cam_idx]
534539
rays = camera_utils.cast_spherical_rays(
535-
camtoworld, self.height[cam_idx], self.width[cam_idx], self.near, self.far, xnp=np)
540+
camtoworld,
541+
self.height[cam_idx], self.width[cam_idx],
542+
self.near[cam_idx],
543+
self.far[cam_idx],
544+
xnp=np)
536545
return utils.Batch(rays=rays)
537546
else:
538547
# Generate rays for all pixels in the image.

0 commit comments

Comments
 (0)