Hello JAX people! I'm trying to write a simulation using JAX. For outside design reasons, I'm tied to a hierarchical data structure like this for my model state:
I would ideally like this state to be a pytree, so that I can use JAX's features to perform operations on it. I coded up an example using … My main question is: …

```python
from __future__ import annotations

from dataclasses import dataclass

import jax
import equinox as eqx


# A plain dataclass is not registered as a pytree, so JAX treats each
# Component as a single opaque leaf.
@dataclass
class Component:
    metadata: str
    data: float


# eqx.Module is a registered pytree, so jax.tree.map traverses its fields.
class Module(eqx.Module):
    components: list[Component]
    submodules: list[Module]


def get_single_level_module():
    return Module(
        components=[
            Component(
                metadata="{name: 'A'}",
                data=0.5
            ),
            Component(
                metadata="{name: 'B'}",
                data=-1.0
            ),
        ],
        submodules=[]
    )


def test_map_nested_module():
    nested_module = get_single_level_module()
    # eqx.Module fields cannot be reassigned, but the list itself is mutable.
    nested_module.submodules.append(get_single_level_module())
    print("Pre-processed module:")
    print(nested_module)

    def double_data(component: Component) -> Component:
        return Component(
            metadata=component.metadata,
            data=component.data * 2
        )

    # Stop at each Component and pass the whole object to double_data.
    processed_module = jax.tree.map(
        double_data,
        nested_module,
        is_leaf=lambda node: isinstance(node, Component)
    )
    print("Processed module:")
    print(processed_module)


if __name__ == "__main__":
    test_map_nested_module()
```
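For comparison, here is a minimal sketch (an assumption, not part of the original question) that registers `Component` itself as a pytree with `jax.tree_util.register_pytree_node`, keeping `metadata` as static auxiliary data and exposing only `data` as a leaf. It also swaps the Equinox `Module` for a plain dataclass registered the same way, purely so the sketch depends only on JAX; the flatten/unflatten helper names are illustrative. With this layout `jax.tree.map` reaches the float values directly, so no `is_leaf` is needed.

```python
from __future__ import annotations

from dataclasses import dataclass, field

import jax


@dataclass
class Component:
    metadata: str  # static: stored in the treedef, never mapped over
    data: float    # dynamic: exposed to JAX as a pytree leaf


def _flatten_component(c: Component):
    # Return (children, aux_data): children become leaves, aux_data stays static.
    return (c.data,), c.metadata


def _unflatten_component(metadata, children):
    (data,) = children
    return Component(metadata=metadata, data=data)


jax.tree_util.register_pytree_node(Component, _flatten_component, _unflatten_component)


# Plain dataclass standing in for the Equinox Module in this sketch.
@dataclass
class Module:
    components: list[Component] = field(default_factory=list)
    submodules: list[Module] = field(default_factory=list)


jax.tree_util.register_pytree_node(
    Module,
    lambda m: ((m.components, m.submodules), None),  # two list children, no aux data
    lambda _aux, children: Module(*children),
)


def get_single_level_module() -> Module:
    return Module(
        components=[
            Component(metadata="{name: 'A'}", data=0.5),
            Component(metadata="{name: 'B'}", data=-1.0),
        ]
    )


nested = get_single_level_module()
nested.submodules.append(get_single_level_module())

# The function now receives each float leaf; metadata is carried through unchanged.
doubled = jax.tree.map(lambda x: x * 2, nested)
print(jax.tree.leaves(doubled))
```

Keeping `metadata` in the auxiliary data means it is part of the tree structure rather than a leaf, so functions mapped over the tree only ever see numeric values.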
`vmap` operates on arrays of values within pytrees, and I don't see any arrays of values in your pytrees here, so `vmap` is probably not applicable.
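To illustrate the point about arrays: in the question, each value is a separate Python float inside its own `Component`, so there is no array axis for `vmap` to map over. A minimal sketch, assuming a hypothetical "struct-of-arrays" layout where each per-component field is stored as one array with a leading axis of length 2 (one entry per component); the `scale` field and `update_component` function are invented for illustration:

```python
import jax
import jax.numpy as jnp

# Assumed struct-of-arrays layout: one array entry per component (leading axis = 2).
names = ("{name: 'A'}", "{name: 'B'}")  # string metadata stays outside the traced pytree
state = {
    "data": jnp.array([0.5, -1.0]),
    "scale": jnp.array([2.0, 3.0]),  # hypothetical second per-component field
}


def update_component(component):
    # Written for ONE component: every leaf is a scalar inside this function.
    return {
        "data": component["data"] * component["scale"],
        "scale": component["scale"],
    }


# vmap maps over axis 0 of every leaf in the pytree at once.
new_state = jax.vmap(update_component)(state)
print(new_state["data"])
```

In this layout the batch dimension lives inside the arrays, which is the situation `vmap` is designed for; a list of individual `Component` objects holding Python scalars does not provide one.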