Replies: 5 comments 4 replies
-
Hi, I think you can handle this the same way as a list of arrays: stack all the values that share a key, so that you end up with a single dictionary.
import jax
import jax.numpy as jnp
# let's say the transformer block looks like this
def dummy_transformer_block(x, param):
    """Toy stand-in for a transformer block: scale ``x`` by the block weight.

    Args:
        x: Input to the block.
        param: Mapping with this block's parameters; only ``param["W"]`` is
            read.

    Returns:
        The product ``param["W"] * x``.
    """
    return param["W"] * x
# consider a list of dictionaries of parameters
params = [dict(W=jnp.ones((10,))), dict(W=jnp.ones((10,)))]
# make a single dictionary: for each key, its value is the stack of all the values for that key in the list
# Every entry of `params` is expected to share the keys of the first one.
keys = list(params[0])
new_params = {key: jnp.stack([p[key] for p in params]) for key in keys}
# `f` function for scan
def body_fn(x, param):
    """Scan step: apply one block to the carried activation.

    Args:
        x: Carry value (the activation entering this block).
        param: Parameters for the current step, forwarded unchanged to
            ``dummy_transformer_block``.

    Returns:
        A ``(carry, output)`` pair; both are the block's result.
    """
    x = dummy_transformer_block(x, param)
    # The same value is both the next carry and this step's output.
    return x, x
x = jnp.ones((10,))
jax.lax.scan(f=body_fn,
init=x,
xs=new_params)
If you're familiar with Equinox, you can check out a similar feature here.
Beta Was this translation helpful? Give feedback.
-
It sounds like you want to convert an array-of-structs into a struct-of-arrays. There is a relevant discussion here: #14073 For your case, you might do something like this: import jax
import jax.numpy as jnp
blocks = [
{'a': 1, 'b': 2},
{'a': 3, 'b': 4},
{'a': 5, 'b': 6}
]
def f_loop(blocks):
    """Sum the ``'a'`` and ``'b'`` entries of every block with a Python loop.

    Args:
        blocks: Iterable of mappings, each providing ``'a'`` and ``'b'``.

    Returns:
        Tuple ``(a, b)`` of the two totals; ``(0, 0)`` when ``blocks`` is
        empty.
    """
    a, b = 0, 0
    for block in blocks:
        a += block['a']
        b += block['b']
    return a, b
a, b = f_loop(blocks)
print(a, b)
# 9 12
def f_scan(blocks):
    """Sum the ``'a'`` and ``'b'`` entries of every block with ``jax.lax.scan``.

    The list of dicts is first transposed into a dict of stacked arrays so
    that it can be passed to ``scan`` as ``xs``.

    Args:
        blocks: Sequence of mappings that all share the keys of the first
            one; the scan body reads ``'a'`` and ``'b'``.

    Returns:
        Tuple ``(a, b)`` of the two totals; ``(0, 0)`` when ``blocks`` is
        empty.
    """
    if not blocks:
        # Nothing to stack; return the same totals a plain loop would.
        return (0, 0)

    blocks_transposed = {
        k: jnp.stack([block[k] for block in blocks])
        for k in blocks[0]
    }

    def body_fun(carry, block):
        a, b = carry
        # No per-step output is needed, so emit None rather than making scan
        # re-stack the inputs only for the caller to discard them.
        return (a + block['a'], b + block['b']), None

    result, _ = jax.lax.scan(body_fun, (0, 0), blocks_transposed)
    return result
a, b = f_scan(blocks)
print(a, b)
# 9 12 |
Beta Was this translation helpful? Give feedback.
-
# Since ... (the sentence introducing this snippet is cut off in the source text)
def access(block, path):
    """Follow a pytree key path into ``block`` and return the value found.

    Args:
        block: Nested container to descend into.
        path: Iterable of key-path entries. An entry with a ``key``
            attribute indexes a mapping, one with ``idx`` indexes a
            sequence, and any other entry is treated as an attribute
            lookup via its ``name``.

    Returns:
        The object reached after applying every entry of ``path`` in order
        (``block`` itself for an empty path).
    """
    for p in path:
        if hasattr(p, "key"):
            block = block[p.key]
        elif hasattr(p, "idx"):
            block = block[p.idx]
        else:
            block = getattr(block, p.name)
    return block
def collector(path, _):
    """Stack the value found at ``path`` in every element of ``blocks``.

    ``blocks`` is read from the enclosing scope.

    Args:
        path: Key path of the current leaf.
        _: The leaf value itself; unused, because the value is looked up in
            each block instead.

    Returns:
        ``jnp.stack`` of the values found at ``path``, one per block.
    """
    return jnp.stack([access(block, path) for block in blocks])
blocks_transposed = jax.tree_util.tree_map_with_path(collector, blocks[0])
def f(x, block):
    """Scan step: run one transformer block on the carried activation.

    Args:
        x: Carry value entering the block.
        block: Mapping of keyword arguments for ``transformer_block``.

    Returns:
        ``(x, None)``: the updated carry and no per-step output.
    """
    # `n_head` is read from the enclosing scope.
    x = transformer_block(x, **block, n_head=n_head)
    return x, None
x, _ = jax.lax.scan(f, x, blocks_transposed) |
Beta Was this translation helpful? Give feedback.
-
Hi @jakevdp, I am having a similar problem, and I was wondering if I could apply the same reasoning here? I have a list of some import chex
import jax
import jax.numpy as jnp
@chex.dataclass
class Obj:
    """Dataclass wrapping a single array field ``x``."""

    x: chex.ArrayDevice

    @jax.jit
    def sum(self):
        """Return ``jnp.sum`` over the ``x`` field."""
        return jnp.sum(self.x)
def f(carry, x):
    """Scan step: multiply the carry by the sum of the current slice.

    Args:
        carry: Running value.
        x: Current scanned element; reduced with ``jnp.sum``.

    Returns:
        ``(carry, carry)``: the updated carry, also emitted as the per-step
        output.
    """
    carry = jnp.sum(x) * carry
    return carry, carry
xs = jnp.ones((5, 2))
x0 = jnp.array([1., 2.])
acc, _ = jax.lax.scan(f, x0, xs=xs)
print(acc) # [32. 64.], this works
def g(carry, x):
    """Scan step: multiply the carry by ``x.sum()``.

    Args:
        carry: Running value.
        x: Current scanned element; it must provide a ``sum()`` method.

    Returns:
        ``(carry, carry)``: the updated carry, also emitted as the per-step
        output.
    """
    carry = x.sum() * carry
    return carry, carry
xs = [Obj(x=jnp.ones((2,))) for _ in range(5)]
acc, _ = jax.lax.scan(g, x0, xs=xs)
print(acc) # AttributeError: 'list' object has no attribute 'sum' |
Beta Was this translation helpful? Give feedback.
-
I implemented the same thing in my llama-2-jax project. The relevant code is:
# https://docs.liesel-project.org/en/v0.1.4/_modules/liesel/goose/pytree.html#stack_leaves
def stack_leaves(pytrees, axis: int = 0):
    """Stack the leaves of one or more PyTrees along a new axis.

    Args:
        pytrees: One or more PyTrees, unpacked as the arguments of
            ``jax.tree_util.tree_map``.
        axis: The axis along which the arrays will be stacked. Default is 0.

    Returns:
        The PyTree with its leaves stacked along the new axis.
    """
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)
# https://gist.github.com/willwhitney/dd89cac6a5b771ccff18b06b33372c75?permalink_comment_id=4634557#gistcomment-4634557
def unstack_leaves(pytrees):
'''
Unstack the leaves of a PyTree.
Args:
pytrees: A PyTree.
Returns:
A list of PyTrees, where each PyTree has the same structure as the input PyTree, but each leaf contains only one part of the original leaf.
'''
leaves, treedef = jax.tree_util.tree_flatten(pytrees)
return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]
params_decoder_blocks = stack_leaves(params_decoder_blocks_old)
@partial(jax.jit, static_argnames=('model_config',))
def forward_decoder(params: Decoder, seq: Array, attn_mask: Array, *, key: rand.KeyArray, model_config: ModelConfig) -> Array:
    """Run the decoder blocks over ``seq`` using ``jax.lax.scan``.

    Args:
        params: Decoder block parameters, passed to ``scan`` as ``xs`` so
            that each step receives one slice of them.
            NOTE(review): presumably the stacked output of ``stack_leaves``;
            confirm against the caller.
        seq: Input carried through the blocks.
        attn_mask: Attention mask forwarded unchanged to every block.
        key: Random key, split once per step with ``split_key_nullable``.
        model_config: Model configuration; a static argument under ``jit``.

    Returns:
        ``seq`` after the final scan step.
    """
    def inner(state, input_):
        key, seq = state
        # Carry the remaining key forward; the block gets its own subkey.
        key, subkey = split_key_nullable(key)
        seq = forward_decoder_block(input_, seq, attn_mask, key=subkey, model_config=model_config)
        return (key, seq), None

    (key, seq), _ = jax.lax.scan(inner, (key, seq), params)
    return seq
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
-
This question is a generalization of #13898. Basically, I am trying to rewrite the following structure with `jax.lax.scan`, where `blocks` is a list of dictionaries. I tried the following, but it doesn't work because `jax.lax.scan` iterates over the innermost dimension. How do I transform `blocks` from a list of dictionaries into a dictionary of lists?
Beta Was this translation helpful? Give feedback.
All reactions