I'm new to JAX, and I want to implement a Gaussian splat renderer (2D for now). My code works, but it's extremely slow:

```python
import math
import jax
import jax.numpy as jnp
from PIL import Image
import numpy as np
gaussian_count = 30
angles = jax.random.uniform(jax.random.key(0), minval=-math.pi, maxval=math.pi, shape=(gaussian_count,))
print(angles)
scales = jax.random.uniform(jax.random.key(0), minval=0.0, maxval=5.0, shape=(gaussian_count,2))
print(scales)
cos = jnp.cos(angles)
print(cos)
sin = jnp.sin(angles)
print(sin)
cos_msin = jnp.column_stack((cos, -sin))
print(cos_msin)
sin_cos = jnp.column_stack((sin, cos))
print(sin_cos)
rotation_matrix = jnp.stack((cos_msin, sin_cos), axis=1)
print(rotation_matrix)
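# Build per-gaussian covariance matrices: cov = (R S)(R S)^T, where R is the
# rotation matrix and S the diagonal scaling matrix.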
scaling_matrices = jnp.zeros(shape=(gaussian_count, 2, 2))
scaling_matrices = scaling_matrices.at[:, 0, 0].set(scales[:, 0])
scaling_matrices = scaling_matrices.at[:, 1, 1].set(scales[:, 1])
print(scaling_matrices)
T = jnp.matmul(rotation_matrix, scaling_matrices)
print(T)
cov = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
print(cov)
means = jax.random.uniform(jax.random.key(0), minval=0.0, maxval=512.0, shape=(gaussian_count,2))
colors = jax.random.uniform(jax.random.key(0), minval=0.0, maxval=1.0, shape=(gaussian_count,4))
#colors = jnp.ones( shape=(gaussian_count,4))
det = cov[:,0,0]*cov[:,1,1] - cov[:,0,1]*cov[:,0,1]
det_inv = 1.0 / det
print(det)
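# The "conic" is the inverse covariance, stored as (A, B, C) so that the
# exponent of the gaussian is -0.5 * (A*dx^2 + 2*B*dx*dy + C*dy^2).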
conic = jnp.zeros(shape=(gaussian_count, 3))
conic = conic.at[:, 0].set(cov[:,1,1]*det_inv)
conic = conic.at[:, 1].set(-cov[:,0,1]*det_inv)
conic = conic.at[:, 2].set(cov[:,0,0]*det_inv)
print(conic)
mid = 0.5 * (cov[:,0,0] + cov[:,1,1] )
lambda1 = mid + jnp.sqrt(jnp.maximum(0.1, mid*mid - det))
lambda2 = mid - jnp.sqrt(jnp.maximum(0.1, mid * mid - det))
radius = jnp.ceil(3.0 * jnp.sqrt(jnp.maximum(lambda1, lambda2)))
print(radius)
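# Bin each gaussian into the 16x16-pixel tiles (32x32 tiles for a 512x512
# image) that its 3-sigma radius overlaps.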
tiles = [[[] for _ in range(32)] for _ in range(32)]
minx = jnp.floor(jnp.maximum(0.0, jnp.floor(means[:, 0] - radius)) / 16)
miny = jnp.floor(jnp.maximum(0.0, jnp.floor(means[:, 1] - radius)) / 16)
maxx = jnp.floor(jnp.minimum(511, jnp.ceil(means[:, 0] + radius)) / 16)
maxy = jnp.floor(jnp.minimum(511, jnp.ceil(means[:, 1] + radius)) / 16)
for i in range(radius.shape[0]):
    for y in range(int(miny[i]), int(maxy[i] + 1)):
        for x in range(int(minx[i]), int(maxx[i] + 1)):
            tiles[y][x].append(i)
print(tiles)
print(minx, maxx)
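# Front-to-back alpha compositing: for each pixel, blend the gaussians binned
# into its tile while tracking the remaining transmittance T.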
image = jnp.zeros(shape=(512,512, 3))
for y in range(0, 512):
    tileY = int(y / 16)
    for x in range(0, 512):
        tileX = int(x / 16)
        T = 1.0
        for t in tiles[tileY][tileX]:
            if det[t] == 0.0:
                continue
            d = jnp.array([x, y]) - means[t]
            power = -0.5 * (conic[t, 0] * d[0] * d[0] + conic[t, 2] * d[1] * d[1]) - conic[t, 1] * d[0] * d[1]
            if power > 0.0:
                continue
            alpha = jnp.minimum(0.99, colors[t, 3] * jnp.exp(power))
            if alpha < 1.0 / 255.0:
                continue
            test_T = T * (1 - alpha)
            if test_T < 0.0001:
                continue
            #print(colors[t,:3] * alpha * T)
            #print( image.at[y,x,:].get() + colors[t,:3] * alpha * T)
            image = image.at[y, x, :].set(image.at[y, x, :].get() + colors[t, :3] * alpha * T)
            T = test_T
#print(image)
image_uint8 = (image * 255).astype(jnp.uint8)
image_np = np.array(image_uint8) # Convert to NumPy array
img = Image.fromarray(image_np)
img.save("output_image.png")
```

It is slow due to the for loops I can't get rid of. Inside the loops, I can't use JAX operators on the tensors, because the calculation is only carried out on selected gaussian splats based on visibility, not on all of them. Essentially the tensors become look-up tables, and the calculation is carried out one pixel / one gaussian at a time. Is there a way to optimize that?
You identified the main issue here: any time you find yourself writing `for` loops over array entries, your code is likely to be slow (this is true in JAX as it is in NumPy). The solution generally is to re-express your computations in terms of vectorized operations. In the simplest cases, it means that instead of something like this:

```python
def f_loop(x: Array) -> Array:
    y = []
    for i in range(len(x)):
        y.append(jnp.exp(0.5 * x[i]))
    return jnp.array(y)
```

you could write something like this (taking advantage of the fact that `jax.numpy` operations work element-wise):

```python
def f_vectorized(x: Array) -> Array:
    return jnp.exp(0.5 * x)
```

Now, vectorized operations don't admit the kind of early-exit conditions that you have in your code. In many cases, though, I suspect you'll find that the cost of the control-flow overhead associated with early exit outweighs the cost of the redundant calculation you're trying to avoid, especially if you're running on GPU or another accelerator. So I'd suggest rewriting your loopy code in terms of vectorized operations, without the early-exit conditions. What do you think?
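For concreteness, here is a rough sketch of what such a rewrite of your render loop could look like. It is only a sketch under a few assumptions: it reuses your `means`, `conic`, `colors`, and `det` arrays, drops the tile binning entirely, evaluates every gaussian at every pixel, and turns the early-exit `continue` statements into `jnp.where` selects; the sequential transmittance update becomes an exclusive cumulative product over the gaussian axis. The name `render_vectorized` is just for illustration.

```python
import jax
import jax.numpy as jnp

def render_vectorized(means, conic, colors, det, height=512, width=512):
    # Pixel coordinates for the whole image: (H, W, 2), ordered as (x, y).
    ys, xs = jnp.mgrid[0:height, 0:width]
    pixels = jnp.stack([xs, ys], axis=-1).astype(jnp.float32)

    # Offset from every pixel to every gaussian mean: (H, W, N, 2).
    d = pixels[:, :, None, :] - means[None, None, :, :]

    # Exponent of each gaussian at each pixel: (H, W, N).
    power = (-0.5 * (conic[:, 0] * d[..., 0] ** 2 + conic[:, 2] * d[..., 1] ** 2)
             - conic[:, 1] * d[..., 0] * d[..., 1])

    # Per-gaussian opacity at each pixel; the `continue` conditions become
    # data-dependent selects that zero out the contribution instead.
    alpha = jnp.minimum(0.99, colors[:, 3] * jnp.exp(power))
    alpha = jnp.where((power > 0.0) | (alpha < 1.0 / 255.0) | (det == 0.0), 0.0, alpha)

    # Front-to-back transmittance T_i = prod_{j<i} (1 - alpha_j), computed as an
    # exclusive cumulative product over the gaussian axis.
    T = jnp.cumprod(1.0 - alpha, axis=-1)
    T = jnp.concatenate([jnp.ones_like(T[..., :1]), T[..., :-1]], axis=-1)

    # Composite every gaussian at once: (H, W, 3).
    return jnp.sum(colors[None, None, :, :3] * alpha[..., None] * T[..., None], axis=2)

image = jax.jit(render_vectorized)(means, conic, colors, det)
```

The intermediate arrays here have shape (512, 512, 30), so memory grows with the gaussian count; if that ever becomes a problem, you can process the image in chunks of rows rather than all at once.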