-
Hi, I'm struggling with the following different autograd behaviors between a plain Python `for` loop and its `jax.lax.scan` equivalent. Here is the original (`for` loop) version:

```python
from typing import Iterator

import chex
import jax
import jax.numpy as jnp
import numpy as np
import optax


def generator() -> Iterator[tuple[chex.Array, chex.Array]]:
    rng = jax.random.PRNGKey(0)
    while True:
        rng, k1, k2 = jax.random.split(rng, num=3)
        x = jax.random.uniform(k1, minval=0.0, maxval=10.0)
        y = 10.0 * x + jax.random.normal(k2)
        yield np.array([x, y])


def f(theta: chex.Array, x: chex.Array) -> chex.Array:
    return x * theta


theta = jax.random.normal(jax.random.PRNGKey(42))

init_learning_rate = jnp.array(0.1)
meta_learning_rate = jnp.array(0.03)

opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=init_learning_rate)
meta_opt = optax.adam(learning_rate=meta_learning_rate)


def loss(inner, x, y):
    return optax.l2_loss(y, f(inner, x))


def step(inner, state, x, y):
    grad = jax.grad(loss)(inner, x, y)
    updates, state = opt.update(grad, state)
    inner = optax.apply_updates(inner, updates)
    return inner, state


@jax.jit
def outer_loss(outer, inner, state, samples):
    state.hyperparams['learning_rate'] = jax.nn.sigmoid(outer)
    for x, y in samples[:-1]:
        inner, state = step(inner, state, x, y)
    x, y = samples[-1]
    return loss(inner, x, y), (inner, state)


@jax.jit
def outer_step(outer, inner, meta_state, state, samples):
    grad, (inner, state) = jax.grad(outer_loss, has_aux=True)(outer, inner, state, samples)
    meta_updates, meta_state = meta_opt.update(grad, meta_state)
    outer = optax.apply_updates(outer, meta_updates)
    return outer, inner, meta_state, state


g = generator()
state = opt.init(theta)
# Inverse sigmoid, to match the value we initialized the inner optimizer with.
eta = -np.log(1. / init_learning_rate - 1)
meta_state = meta_opt.init(eta)

N = 7
learning_rates = []
thetas = []
for i in range(2000):
    samples = [next(g) for _ in range(N)]
    eta, theta, meta_state, state = outer_step(eta, theta, meta_state, state, samples)
    learning_rates.append(jax.nn.sigmoid(eta))
    thetas.append(theta)
```

If I understand correctly, the `for` loop in `outer_loss` can be replaced with `jax.lax.scan` like this:

```python
# for x, y in samples[:-1]:
#     inner, state = step(inner, state, x, y)
def f(carry, sample):
    inner, state = carry
    inner, state = step(inner, state, sample[0], sample[1])
    return (inner, state), None

(inner, state), _ = jax.lax.scan(f, (inner, state), samples[:-1])
```

However, this changes the behavior quite drastically. Specifically, the original version behaves like this:

*(plot omitted)*

but the modified one behaves like this:

*(plot omitted)*

I think autograd works differently in the two versions, so the behaviors are different, but I could not figure out the details.
-
Hi,

I think the main problem here is that the code scans over a list: `samples` is a list of arrays, and `jax.lax.scan` does not scan over a Python list (see #13898 for more). Instead, it looks at each element of the list and scans over each of them individually, along its own leading axis. The code can obtain the expected result by changing

```python
(inner, state), _ = jax.lax.scan(f, (inner, state), samples[:-1])
```

to

```python
(inner, state), _ = jax.lax.scan(f, (inner, state), jnp.stack(samples)[:-1])
```
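
To make the difference concrete, here is a small self-contained sketch; the toy `samples` list and the `body` function below are made up for illustration and are not part of the original code:

```python
import jax
import jax.numpy as jnp

# A list of 4 samples, each an array of shape (3,).
samples = [jnp.arange(3.0) + 10.0 * i for i in range(4)]

def body(carry, x):
    # Pass the per-step input straight through so we can inspect what scan fed us.
    return carry, x

# Scanning over the list: scan treats it as a pytree with 4 leaves and scans
# over the leading axis of each leaf, so it runs 3 steps and each per-step
# `x` is a list of 4 scalars (one element from each sample).
_, ys_list = jax.lax.scan(body, 0.0, samples)
print(jax.tree_util.tree_map(jnp.shape, ys_list))  # [(3,), (3,), (3,), (3,)]

# Scanning over the stacked array: 4 steps, and each per-step `x` is one
# whole sample of shape (3,).
_, ys_stacked = jax.lax.scan(body, 0.0, jnp.stack(samples))
print(ys_stacked.shape)  # (4, 3)
```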
Here is the colab for reference.
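
For completeness, here is a minimal sketch of what the scan-based `outer_loss` could look like with this fix applied, assuming the imports and the `loss`, `step`, and `opt` definitions from the question. I've renamed the inner scan function to `body` so it does not shadow the model `f`; this is only a sketch, not necessarily the exact code from the colab:

```python
@jax.jit
def outer_loss(outer, inner, state, samples):
    state.hyperparams['learning_rate'] = jax.nn.sigmoid(outer)
    # Stack the list of (x, y) samples into one array of shape (N, 2) so that
    # scan iterates over samples rather than over the list's pytree leaves.
    samples = jnp.stack(samples)

    def body(carry, sample):
        inner, state = carry
        inner, state = step(inner, state, sample[0], sample[1])
        return (inner, state), None

    (inner, state), _ = jax.lax.scan(body, (inner, state), samples[:-1])
    x, y = samples[-1]
    return loss(inner, x, y), (inner, state)
```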