How do I pass objects of a custom class into functions that use jit acceleration? #15351
-
Hi, thanks for the question! The way to do this is to register your class as a custom pytree. You can see details here: https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees

But please note that even if your class were a pytree, your code as written will not be compatible with JAX transforms like JIT, because jitted functions must be pure: any change made to an object inside the jitted function is not visible to the caller unless the function returns that object.

Addressing both these issues, you could modify your code like this:

```python
import jax
from jax import tree_util


class Replay_buffer():
    def __init__(self):
        self.pri = 0

    def update_pri(self, a, b):
        self.pri = a + b

    def sample(self, c, d):
        self.pri = c + d


# Split a Replay_buffer into its array leaves (children) and static metadata (aux_data).
def flatten_replay(obj):
    children = (obj.pri,)
    aux_data = ()
    return children, aux_data


# Rebuild a Replay_buffer from the leaves produced by flatten_replay.
def unflatten_replay(aux_data, children):
    obj = Replay_buffer()
    obj.pri = children[0]
    return obj


tree_util.register_pytree_node(Replay_buffer, flatten_replay, unflatten_replay)


@jax.jit
def update(utd_ratio, replay_buffer):
    utd_ratio = 1 + utd_ratio
    replay_buffer.update_pri(1, 2)
    # Return the buffer so the update is visible outside the jitted function.
    return utd_ratio, replay_buffer


replay_buffer = Replay_buffer()
s, replay_buffer = update(20, replay_buffer)
print(s)
# 21
print(replay_buffer.pri)
# 3
```
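As an alternative sketch (not part of the original answer), the same pattern can be written with `jax.tree_util.register_pytree_node_class`, which registers a class that defines a `tree_flatten` method and a `tree_unflatten` classmethod. The class name `ReplayBuffer` is hypothetical, and having `update_pri` return a new buffer instead of mutating `self` is an assumed design choice; it keeps the jitted function pure, so no state change can be silently lost.

```python
import jax
from jax import tree_util


# Hypothetical variant of Replay_buffer, registered via the decorator form.
@tree_util.register_pytree_node_class
class ReplayBuffer:
    def __init__(self, pri=0):
        self.pri = pri

    def update_pri(self, a, b):
        # Functional update (assumption): return a new buffer, leave self unchanged.
        return ReplayBuffer(a + b)

    def tree_flatten(self):
        children = (self.pri,)  # leaves that jit traces
        aux_data = None         # no static metadata in this example
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def update(utd_ratio, replay_buffer):
    return utd_ratio + 1, replay_buffer.update_pri(1, 2)


buffer = ReplayBuffer()
s, new_buffer = update(20, buffer)
print(s)               # 21
print(new_buffer.pri)  # 3
print(buffer.pri)      # 0: the original object is untouched
```

Because the updated buffer is an explicit return value, the caller decides whether to keep the old or new state, which matches how JAX expects transformed functions to behave.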
-
Thank you very much for your help!