@@ -329,7 +329,11 @@ def __init__(self,
329
329
if len (self .pixtocams .shape ) == 2 :
330
330
self .pixtocams = np .repeat (self .pixtocams [None ], self ._n_examples , 0 )
331
331
if not isinstance (self .focal , np .ndarray ):
332
- self .focal = np .full ((self ._n_examples ,), self .focal , dtype = np .int32 )
332
+ self .focal = np .full ((self ._n_examples ,), self .focal , dtype = np .float32 )
333
+ if not isinstance (self .near , np .ndarray ):
334
+ self .near = np .full ((self ._n_examples ,), self .near , dtype = np .float32 )
335
+ if not isinstance (self .height , np .ndarray ):
336
+ self .far = np .full ((self ._n_examples ,), self .far , dtype = np .float32 )
333
337
if not isinstance (self .width , np .ndarray ):
334
338
self .width = np .full ((self ._n_examples ,), self .width , dtype = np .int32 )
335
339
if not isinstance (self .height , np .ndarray ):
@@ -450,10 +454,11 @@ def _make_ray_batch(self,
450
454
"""
451
455
452
456
broadcast_scalar = lambda x : np .broadcast_to (x , pix_x_int .shape )[..., None ]
457
+ idx = 0 if self .render_path else cam_idx
453
458
ray_kwargs = {
454
459
'lossmult' : broadcast_scalar (1. ) if lossmult is None else lossmult ,
455
- 'near' : broadcast_scalar (self .near ),
456
- 'far' : broadcast_scalar (self .far ),
460
+ 'near' : broadcast_scalar (self .near [ idx ] ),
461
+ 'far' : broadcast_scalar (self .far [ idx ] ),
457
462
'cam_idx' : broadcast_scalar (cam_idx ),
458
463
}
459
464
# Collect per-camera information needed for each ray.
@@ -532,7 +537,11 @@ def generate_ray_batch(self, cam_idx: int) -> utils.Batch:
532
537
if self ._render_spherical :
533
538
camtoworld = self .camtoworlds [cam_idx ]
534
539
rays = camera_utils .cast_spherical_rays (
535
- camtoworld , self .height [cam_idx ], self .width [cam_idx ], self .near , self .far , xnp = np )
540
+ camtoworld ,
541
+ self .height [cam_idx ], self .width [cam_idx ],
542
+ self .near [cam_idx ],
543
+ self .far [cam_idx ],
544
+ xnp = np )
536
545
return utils .Batch (rays = rays )
537
546
else :
538
547
# Generate rays for all pixels in the image.
0 commit comments