Skip to content

Commit 95f1a65

Browse files
committed
Merge gsplat implementation from @jclarkk pull request
- Merged jclarkk's code onto more recent Trellis code - See microsoft#73
1 parent 0a19f9b commit 95f1a65

File tree

5 files changed

+136
-10
lines changed

5 files changed

+136
-10
lines changed

setup.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Read Arguments
2-
TEMP=`getopt -o h --long help,new-env,basic,train,xformers,flash-attn,diffoctreerast,vox2seq,spconv,mipgaussian,kaolin,nvdiffrast,demo -n 'setup.sh' -- "$@"`
2+
TEMP=`getopt -o h --long help,new-env,basic,train,xformers,flash-attn,diffoctreerast,vox2seq,spconv,mipgaussian,kaolin,nvdiffrast,gsplat,demo -n 'setup.sh' -- "$@"`
33

44
eval set -- "$TEMP"
55

@@ -17,6 +17,7 @@ ERROR=false
1717
MIPGAUSSIAN=false
1818
KAOLIN=false
1919
NVDIFFRAST=false
20+
GSPLAT=false
2021
DEMO=false
2122

2223
if [ "$#" -eq 1 ] ; then
@@ -37,6 +38,7 @@ while true ; do
3738
--mipgaussian) MIPGAUSSIAN=true ; shift ;;
3839
--kaolin) KAOLIN=true ; shift ;;
3940
--nvdiffrast) NVDIFFRAST=true ; shift ;;
41+
--gsplat) GSPLAT=true ; shift ;;
4042
--demo) DEMO=true ; shift ;;
4143
--) shift ; break ;;
4244
*) ERROR=true ; break ;;
@@ -63,6 +65,7 @@ if [ "$HELP" = true ] ; then
6365
echo " --mipgaussian Install mip-splatting"
6466
echo " --kaolin Install kaolin"
6567
echo " --nvdiffrast Install nvdiffrast"
68+
echo " --gsplat Install gsplat"
6669
echo " --demo Install all dependencies for demo"
6770
return
6871
fi
@@ -258,3 +261,11 @@ fi
258261
if [ "$DEMO" = true ] ; then
259262
pip install gradio==4.44.1 gradio_litmodel3d==0.0.1
260263
fi
264+
265+
if [ "$GSPLAT" = true ] ; then
266+
if [ "$PLATFORM" = "cuda" ] ; then
267+
pip install git+https://github.com/nerfstudio-project/gsplat
268+
else
269+
echo "[GSPLAT] Unsupported platform: $PLATFORM"
270+
fi
271+
fi

trellis/renderers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
__attributes = {
44
'OctreeRenderer': 'octree_renderer',
55
'GaussianRenderer': 'gaussian_render',
6+
'GSplatRenderer': 'gsplat_renderer',
67
'MeshRenderer': 'mesh_renderer',
78
}
89

@@ -28,4 +29,5 @@ def __getattr__(name):
2829
if __name__ == '__main__':
2930
from .octree_renderer import OctreeRenderer
3031
from .gaussian_render import GaussianRenderer
32+
from .gsplat_renderer import GSplatRenderer
3133
from .mesh_renderer import MeshRenderer

trellis/renderers/gsplat_renderer.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import gsplat as gs
2+
import numpy as np
3+
import torch
4+
import torch.nn.functional as F
5+
from easydict import EasyDict as edict
6+
7+
8+
class GSplatRenderer:
9+
def __init__(self, rendering_options={}) -> None:
10+
self.pipe = edict({
11+
"kernel_size": 0.1,
12+
"convert_SHs_python": False,
13+
"compute_cov3D_python": False,
14+
"scale_modifier": 1.0,
15+
"debug": False,
16+
"use_mip_gaussian": True
17+
})
18+
self.rendering_options = edict({
19+
"resolution": None,
20+
"near": None,
21+
"far": None,
22+
"ssaa": 1,
23+
"bg_color": 'random',
24+
})
25+
self.rendering_options.update(rendering_options)
26+
self.bg_color = None
27+
28+
def render(
29+
self,
30+
gaussian,
31+
extrinsics: torch.Tensor,
32+
intrinsics: torch.Tensor,
33+
colors_overwrite: torch.Tensor = None
34+
) -> edict:
35+
36+
resolution = self.rendering_options["resolution"]
37+
ssaa = self.rendering_options["ssaa"]
38+
39+
if self.rendering_options["bg_color"] == 'random':
40+
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
41+
if np.random.rand() < 0.5:
42+
self.bg_color += 1
43+
else:
44+
self.bg_color = torch.tensor(
45+
self.rendering_options["bg_color"],
46+
dtype=torch.float32,
47+
device="cuda"
48+
)
49+
50+
height = resolution * ssaa
51+
width = resolution * ssaa
52+
53+
# Set up background color
54+
if self.rendering_options["bg_color"] == 'random':
55+
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
56+
if np.random.rand() < 0.5:
57+
self.bg_color += 1
58+
else:
59+
self.bg_color = torch.tensor(
60+
self.rendering_options["bg_color"],
61+
dtype=torch.float32,
62+
device="cuda"
63+
)
64+
65+
Ks_scaled = intrinsics.clone()
66+
Ks_scaled[0, 0] *= width
67+
Ks_scaled[1, 1] *= height
68+
Ks_scaled[0, 2] *= width
69+
Ks_scaled[1, 2] *= height
70+
Ks_scaled = Ks_scaled.unsqueeze(0)
71+
72+
near_plane = 0.01
73+
far_plane = 1000.0
74+
75+
# Rasterize with gsplat
76+
render_colors, render_alphas, meta = gs.rasterization(
77+
means=gaussian.get_xyz,
78+
quats=F.normalize(gaussian.get_rotation, dim=-1),
79+
scales=gaussian.get_scaling / intrinsics[0, 0],
80+
opacities=gaussian.get_opacity.squeeze(-1),
81+
colors=colors_overwrite.unsqueeze(0) if colors_overwrite is not None else torch.sigmoid(
82+
gaussian.get_features.squeeze(1)).unsqueeze(0),
83+
viewmats=extrinsics.unsqueeze(0),
84+
Ks=Ks_scaled,
85+
width=width,
86+
height=height,
87+
near_plane=near_plane,
88+
far_plane=far_plane,
89+
radius_clip=3.0,
90+
eps2d=0.3,
91+
render_mode="RGB",
92+
backgrounds=self.bg_color.unsqueeze(0),
93+
camera_model="pinhole"
94+
)
95+
96+
rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1)
97+
98+
# Apply supersampling if needed
99+
if ssaa > 1:
100+
rendered_image = F.interpolate(
101+
rendered_image[None],
102+
size=(resolution, resolution),
103+
mode='bilinear',
104+
align_corners=False,
105+
antialias=True
106+
).squeeze()
107+
108+
return edict({'color': rendered_image})

trellis/utils/postprocessing_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def to_glb(
405405
texture_size: int = 1024,
406406
debug: bool = False,
407407
verbose: bool = True,
408+
gs_renderer='gsplat',
408409
) -> trimesh.Trimesh:
409410
"""
410411
Convert a generated asset to a glb file.
@@ -418,6 +419,7 @@ def to_glb(
418419
texture_size (int): Size of the texture.
419420
debug (bool): Whether to print debug information.
420421
verbose (bool): Whether to print progress.
422+
gs_renderer (str): Name of the renderer to use for gaussian splatting rendering.
421423
"""
422424
vertices = mesh.vertices.cpu().numpy()
423425
faces = mesh.faces.cpu().numpy()
@@ -433,14 +435,14 @@ def to_glb(
433435
fill_holes_resolution=1024,
434436
fill_holes_num_views=1000,
435437
debug=debug,
436-
verbose=verbose,
438+
verbose=verbose
437439
)
438440

439441
# parametrize mesh
440442
vertices, faces, uvs = parametrize_mesh(vertices, faces)
441443

442444
# bake texture
443-
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
445+
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100, gs_renderer=gs_renderer)
444446
masks = [np.any(observation > 0, axis=-1) for observation in observations]
445447
extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
446448
intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]

trellis/utils/render_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import utils3d
55
from PIL import Image
66

7-
from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
7+
from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer, GSplatRenderer
88
from ..representations import Octree, Gaussian, MeshExtractResult
99
from ..modules import sparse as sp
1010
from .random_utils import sphere_hammersley_sequence
@@ -40,7 +40,7 @@ def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
4040
return extrinsics, intrinsics
4141

4242

43-
def get_renderer(sample, **kwargs):
43+
def get_renderer(sample, gs_renderer='gsplat', **kwargs):
4444
if isinstance(sample, Octree):
4545
renderer = OctreeRenderer()
4646
renderer.rendering_options.resolution = kwargs.get('resolution', 512)
@@ -50,7 +50,10 @@ def get_renderer(sample, **kwargs):
5050
renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
5151
renderer.pipe.primitive = sample.primitive
5252
elif isinstance(sample, Gaussian):
53-
renderer = GaussianRenderer()
53+
if gs_renderer == 'gsplat':
54+
renderer = GSplatRenderer()
55+
else:
56+
renderer = GaussianRenderer()
5457
renderer.rendering_options.resolution = kwargs.get('resolution', 512)
5558
renderer.rendering_options.near = kwargs.get('near', 0.8)
5659
renderer.rendering_options.far = kwargs.get('far', 1.6)
@@ -69,8 +72,8 @@ def get_renderer(sample, **kwargs):
6972
return renderer
7073

7174

72-
def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
73-
renderer = get_renderer(sample, **options)
75+
def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, gs_renderer='gsplat', **kwargs):
76+
renderer = get_renderer(sample, gs_renderer, **options)
7477
rets = {}
7578
for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
7679
if isinstance(sample, MeshExtractResult):
@@ -100,14 +103,14 @@ def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2
100103
return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
101104

102105

103-
def render_multiview(sample, resolution=512, nviews=30):
106+
def render_multiview(sample, resolution=512, nviews=30, gs_renderer='gsplat'):
104107
r = 2
105108
fov = 40
106109
cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
107110
yaws = [cam[0] for cam in cams]
108111
pitchs = [cam[1] for cam in cams]
109112
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
110-
res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
113+
res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}, gs_renderer=gs_renderer)
111114
return res['color'], extrinsics, intrinsics
112115

113116

0 commit comments

Comments
 (0)