I have this `Block` dataclass that contains both numerical attributes, which are easy to stack, and a static attribute. Ideally, the stacking operations below should give something like

```
{
    "width": [1.0, 3.0],
    "height": [2.0, 4.0],
    "name": ["block1", "block2"]
}
```

but the stacking throws. I wonder if there's a way to make the …

```python
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Block:
    width: float
    height: float
    name: str = field(metadata=dict(static=True))


def stack_fn(*sequence: Block) -> Block:
    return jax.tree.map(lambda *leaves: jnp.stack(leaves, axis=0), *sequence)


def double_block_size(block: Block) -> Block:
    return Block(width=block.width * 2, height=block.height * 2, name=block.name)


blocks = (
    Block(width=1.0, height=2.0, name="block1"),
    Block(width=3.0, height=4.0, name="block2"),
)

# stacking throws "ValueError: Mismatch custom dataclass node data: ('block1',) != ('block2',)"
stacked_blocks = stack_fn(*blocks)
doubled_blocks = jax.vmap(double_block_size)(blocks)
```
I don't think `jax.tree` utilities will help in doing the operation you have in mind, because there is no way to manipulate static metadata via tree flattening. Depending on your real-world case, you might try doing it directly using Python builtins, e.g. something like this:

```python
def stack_fn(*blocks):
    # Transpose per-block (width, height, name) triples into per-field tuples,
    # e.g. width=(1.0, 3.0), height=(2.0, 4.0), name=("block1", "block2").
    return Block(*zip(*((b.width, b.height, b.name) for b in blocks)))
```