You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi there, I am trying to reproduce Figure 1 in Klein et al. (2024). I am having trouble implementing the fourth cost function, $h(z) = \frac{1}{2} \|z\|_2^2 + \gamma \frac{1}{2} \|b^\perp z\|_2^2$. I am using `orthogonal_regularizer = regularizers.Quadratic(A, is_factor=True, is_complement=True)`
as the regularizer, which I then use for the cost function `l2_b_cost = costs.RegTICost(orthogonal_regularizer, lam=1000)`,
which I then pass to `map_l2_b = entropic_map(xs, y, cost_fn = l2_b_cost)`.
The `entropic_map` function is taken from a tutorial.
The resulting transport looks like this (the red arrow is the direction of $b$; the orange points are the transported points):
This does not penalize transport orthogonal to $b$ as expected: for $b = [0, 1]$ the points should be transported parallel to the y-axis, but they are not. I also tried `orthogonal_regularizer = regularizers.Orthogonal(f=l2_regularizer, A=P)`, but that did not work either.
I am wondering where my mistake is and what I can do to implement this properly.
Code to reproduce:
from typing import Callable

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ott.geometry import costs, pointcloud, regularizers
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
def plot_displacements(xs, xt, title, g, b_vec=None, *, limitx=5, limity=3):
    """Plot source -> target displacements over the level lines of ``g``.

    Args:
        xs: Source points, iterated as 2-D coordinates ``(x[0], x[1])``.
        xt: Transported points, paired element-wise with ``xs``.
        title: Figure title (may contain matplotlib mathtext).
        g: Scalar function of a single 2-D point; it is vmapped over a grid
            to draw its contour lines.
        b_vec: Optional 2-D direction, drawn as a red unit-length arrow.
        limitx: Upper x-limit of the plotted window (the lower limit is 0).
        limity: Upper y-limit of the plotted window (the lower limit is 0).
    """
    plt.figure(figsize=(4.5, 4.5))

    # Evaluate g on a 100 x 100 grid covering [0, limitx] x [0, limity] and
    # draw its level lines as the background.
    grid_x, grid_y = jnp.meshgrid(
        jnp.linspace(0, limitx, 100), jnp.linspace(0, limity, 100)
    )
    grid_points = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
    g_vals = jax.vmap(g)(grid_points).reshape(grid_x.shape)
    plt.contour(grid_x, grid_y, g_vals, levels=20, linewidths=0.7,
                colors='gray', linestyles='solid')

    # One black segment per (source, target) pair; blue squares mark the
    # sources and orange circles mark the transported points.
    for x, x_t in zip(xs, xt):
        plt.plot([x[0], x_t[0]], [x[1], x_t[1]], 'k-', lw=1)
        plt.plot(x[0], x[1], 's', color='cornflowerblue', markersize=8,
                 mew=1.2, mec='black')
        plt.plot(x_t[0], x_t[1], 'o', color='orange', markersize=6,
                 mew=1.2, mec='black')

    if b_vec is not None:
        # Normalize so that only the direction of b is shown. The arrow is
        # anchored at a fixed point near the right edge of the window.
        b_unit = b_vec / jnp.linalg.norm(b_vec)
        base = jnp.array([limitx - 1, limity / 2])
        plt.arrow(base[0], base[1], b_unit[0], b_unit[1],
                  width=0.05, head_width=0.2, color='red')

    plt.xlim(0, limitx)
    plt.ylim(0, limity)
    plt.grid(False)
    plt.gca().set_aspect('equal')
    plt.title(title)
    plt.tight_layout()
    plt.show()
# Source point cloud: num_pts points drawn uniformly from the box
# [1, 5] x [0.5, 3] (first quadrant), reproducibly from PRNG seed 0.
num_pts = 20
key, subkey = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.uniform(
    subkey,
    shape=(num_pts, 2),
    minval=jnp.array([1, 0.5]),
    maxval=jnp.array([5, 3]),
)
def g(x):
    """Smooth concave toy potential -1/2 (x - mu)^T Sigma^{-1} (x - mu).

    ``mu`` is the origin and ``Sigma`` is a fixed symmetric positive-definite
    2x2 matrix, so ``g`` is a concave quadratic that attains its maximum
    value 0 at the origin. ``x`` is a single point of shape ``(2,)``.
    """
    center = jnp.array([0.0, 0.0])
    cov = jnp.array([[3, -0.5],
                     [-0.5, 1.5]])
    offset = x - center
    return -0.5 * offset.T @ jnp.linalg.inv(cov) @ offset
# Build the default Sinkhorn solver once and wrap it in jax.jit; entropic_map
# reuses this module-level instance for every cost. (The vmapped, jitted
# grad_g is defined further down, right before its only use.)
solver = jax.jit(sinkhorn.Sinkhorn())
def entropic_map(
    x, y, cost_fn: costs.TICost
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Estimate an entropic OT map from point cloud ``x`` to point cloud ``y``.

    Builds a ``PointCloud`` geometry between ``x`` and ``y`` with
    ``cost_fn``, solves the linear OT problem with the module-level
    ``solver``, and returns the ``transport`` attribute of the resulting
    dual potentials.

    Returns:
        The callable ``dual_potentials.transport``. It is applied to an array
        of points, ``map_fn(points)``, to obtain their transported positions.
        It is a function, not an array.
    """
    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
    output = solver(linear_problem.LinearProblem(geom))
    return output.to_dual_potentials().transport
# Direction b (here b = [0, 1]) along which displacements should be cheap,
# stored as a (1, d) row matrix. A float dtype avoids integer linear algebra.
A = jnp.atleast_2d(jnp.array([0.0, 1.0]))
# Orthogonal projector onto the complement of the row space of A:
#   P = I_d - A^T (A A^T)^{-1} A.
# The identity must be d x d, i.e. size A.shape[1]. Using A.shape[0] (the
# number of directions, 1 here) broadcasts a 1x1 identity and yields
# [[1, 1], [1, 0]], which is not a projector.
P = jnp.eye(A.shape[1]) - A.T @ jnp.linalg.solve(A @ A.T, A)
# Regularizers that are combined with the base cost in the RegTICost
# constructions below.
l2_regularizer = regularizers.SqL2()
l1_regularizer = regularizers.L1()
# Intended to penalize the part of a displacement orthogonal to b (the row of
# A). NOTE(review): confirm in the OTT docs exactly what Quadratic computes
# with is_factor=True and is_complement=True for a (1, d) matrix A before
# relying on it.
orthogonal_regularizer = regularizers.Quadratic(
    A, is_complement=True, is_factor=True
)
# Alternative that was also tried, using the projector P defined above:
# orthogonal_regularizer = regularizers.Orthogonal(f=l2_regularizer, A=P)
# Target cloud: shift each source point by the gradient of the concave
# potential g, i.e. y_i = x_i + grad g(x_i). grad_g is vmapped over points and
# jit-compiled. y depends only on xs and g, not on any cost function.
grad_g = jax.jit(jax.vmap(jax.grad(g)))
y = grad_g(xs) + xs
# Fit one entropic map per cost, always between the same clouds xs -> y.
#   l2_cost:    costs.SqEuclidean()
#   l2_l1_cost: RegTICost with the L1 regularizer, lam = 100
#   l2_b_cost:  RegTICost with orthogonal_regularizer, lam = 1000
# NOTE(review): y is computed once, before any cost is chosen, and is shared
# by all three problems. A cost that makes moves along b cheap can presumably
# only change how xs is matched to these fixed targets, not where the targets
# are, so the third plot need not show displacements parallel to b. To
# visualize the displacement a cost h itself induces for the potential g,
# one option is to evaluate T_h(x) = x + grad h^*(grad g(x)) directly (for
# h(z) = 1/2 z^T M z, grad h^*(w) = M^{-1} w). Confirm the sign convention
# against Klein et al. (2024) first.
l2_cost = costs.SqEuclidean()
map_l2 = entropic_map(xs, y, cost_fn=l2_cost)

l2_l1_cost = costs.RegTICost(l1_regularizer, lam=100)
map_l2_l1 = entropic_map(xs, y, cost_fn=l2_l1_cost)

l2_b_cost = costs.RegTICost(orthogonal_regularizer, lam=1000)
map_l2_b = entropic_map(xs, y, cost_fn=l2_b_cost)
# One figure per cost. The last figure also draws b; for b = [0, 1] its
# displacements are expected to be mostly parallel to the y-axis.
for transport_fn, panel_title, panel_kwargs in (
    (map_l2, r"$h(z) = \frac{1}{2} \|z\|_2^2$", {}),
    (map_l2_l1, r"$h(z) = \frac{1}{2} \|z\|_2^2 + \gamma \|z\|_1$", {}),
    (
        map_l2_b,
        r"$h(z) = \frac{1}{2} \|z\|_2^2 + \gamma \|b^\perp z\|^2$",
        {"b_vec": jnp.array([0, 1])},
    ),
):
    plot_displacements(xs, transport_fn(xs), panel_title, g, **panel_kwargs)
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
Hi there, I am trying to reproduce Figure 1 in Klein et al. (2024). I am having trouble implementing the fourth cost function, where $h(z) = \frac{1}{2} \|z\|_2^2 + \gamma \frac{1}{2} \|b^\perp z\|_2^2$. I am using

orthogonal_regularizer = regularizers.Quadratic(A, is_factor=True, is_complement=True)
as the regularizer, which I then use for the cost function
l2_b_cost = costs.RegTICost(orthogonal_regularizer, lam=1000)
which I then pass to
map_l2_b = entropic_map(xs, y, cost_fn = l2_b_cost)
. The `entropic_map` function is taken from a tutorial.
The resulting transport looks like this (the red arrow is the direction of $b$; the orange points are the transported points),
which does not penalize transport orthogonal to $b$ as expected: for $b = [0, 1]$ the points should be transported parallel to the y-axis, but they are not. I also tried using
orthogonal_regularizer = regularizers.Orthogonal(f=l2_regularizer, A=P)
, but that did not work either. I am wondering where my mistake is and what I can do to implement this properly.
Code to reproduce:
Beta Was this translation helpful? Give feedback.
All reactions