Store results for later use in the PyTree & jit/vmap context? #17268
jecampagne asked this question in Q&A
Replies: 2 comments
Hi, after thinking about it a bit, I see a possible (simple) solution where "_ws" is just initialized to an empty dictionary in the __init__ method:

```python
import jax
from jax.tree_util import register_pytree_node_class

# Base (which provides tree_flatten/tree_unflatten and the params dict) is
# defined in the notebook and is not reproduced here.


@register_pytree_node_class
class A(Base):
    def __init__(
        self,
        beta,
        scale,
        trunc=0.0,
        gsparams=None,
    ):
        super().__init__(
            beta=beta,
            scale=scale,
            trunc=trunc,
            gsparams=gsparams,
        )
        # Create a workspace where functions can store some precomputed
        # results
        self._ws = {}  #####

    @property
    def beta(self):
        return self.params["beta"]

    @property
    def trunc(self):
        return self.params["trunc"]

    @property
    def scale(self):
        return self.params["scale"]

    @property
    def _y_trunc(self):
        return self.scale * self.trunc**self.beta

    @property
    def _y_untrunc(self):
        return self.scale**self.beta

    @property
    def _y(self):
        if "_y" in self._ws:  ####
            # Reuse the value stored on a previous call
            print("use record")
            return self._ws["_y"]
        else:
            res = jax.lax.select(self.trunc > 0, self._y_trunc, self._y_untrunc)
            self._ws["_y"] = res
            return res

    def __hash__(self):
        return hash(
            (
                "A",
                self.beta,
                self.scale,
                self.trunc,
                self.gsparams,
            )
        )
```

But I was wondering whether there is a mechanism to put "property"-like functions in a cache?
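Python's standard library does provide such a mechanism: `functools.cached_property`. Below is a minimal sketch, assuming a hypothetical, simplified version of the class (no `Base`, no `gsparams`, plain attributes instead of `params`); the `n_evals` counter exists only to show when `_y` is computed. Because JAX rebuilds the object with `tree_unflatten` when it crosses a `jit`/`vmap` boundary, the cache does not travel with it: it saves work for one eager object or within one trace, but the traced copy starts empty.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class A:
    """Hypothetical, simplified stand-in for the class above (no Base, no gsparams)."""

    n_evals = 0  # class-level counter, only used to show when _y is computed

    def __init__(self, beta, scale, trunc=0.0):
        # Keep __init__ trivial: JAX calls it again through tree_unflatten.
        self.beta = beta
        self.scale = scale
        self.trunc = trunc

    @functools.cached_property
    def _y(self):
        # Computed on first access, then stored in this instance's __dict__.
        A.n_evals += 1
        return jnp.where(
            self.trunc > 0,
            self.scale * self.trunc**self.beta,
            self.scale**self.beta,
        )

    def tree_flatten(self):
        # Only the parameters are leaves; the cached _y is never flattened.
        return (self.beta, self.scale, self.trunc), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


obj = A(beta=2.0, scale=3.0)
first, second = obj._y, obj._y  # the second access reads the cache
print(first, second, A.n_evals)

# jit rebuilds obj with tree_unflatten, so the traced copy starts with an
# empty cache and _y is computed once more, during tracing only.
double_y = jax.jit(lambda o: 2.0 * o._y)
print(double_y(obj), A.n_evals)
```

One caution, as far as I understand JAX's tracing model: reading a cached property on an object that a jitted function closes over (rather than receives as an argument) would store a tracer on that outer object and leak it.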
No interest?
Admittedly, the text is too long, sorry. In particular, please look at the second snippet below (NB: I have kept the first snippet as well, to show what I was looking for).
I am asking whether it is possible to store the results of (possibly long) computations in a sort of work space in the context of PyTrees.
I have shared a snippet on Colab in this notebook.
Below is a description of the use case, followed by the snippet.
First, I have a base class that implements common functionality such as the tree_flatten/tree_unflatten functions, as well as the storage of two kinds of parameters. In what follows we can ignore the gsparams-like parameters, although I have left that code as is.
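The notebook's `Base` class is not reproduced here. To give a rough idea of the pattern, here is a hypothetical minimal sketch, assuming `params` is a dict of array-like values (the traced leaves) and `gsparams` is a hashable, static setting kept in the auxiliary data. Note that pytree registration applies to one exact class, so each subclass needs the decorator too.

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class Base:
    """Hypothetical base: `params` entries are traced leaves, `gsparams` is static."""

    def __init__(self, gsparams=None, **params):
        # Only store the arguments: tree_unflatten calls __init__ again,
        # with tracers (or other placeholders) instead of numbers.
        self._params = params
        self._gsparams = gsparams

    @property
    def params(self):
        return self._params

    @property
    def gsparams(self):
        return self._gsparams

    def tree_flatten(self):
        children = (self.params,)  # a dict is itself a pytree of leaves
        aux_data = (self.gsparams,)  # static part; must be hashable for jit
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (params,) = children
        (gsparams,) = aux_data
        return cls(gsparams=gsparams, **params)


# Registration is per class, so the subclass is decorated as well.
@register_pytree_node_class
class A(Base):
    def __init__(self, beta, scale, trunc=0.0, gsparams=None):
        super().__init__(beta=beta, scale=scale, trunc=trunc, gsparams=gsparams)


obj = A(beta=2.0, scale=3.0, trunc=0.5)
leaves, treedef = tree_flatten(obj)  # dict leaves come out in sorted key order
obj2 = tree_unflatten(treedef, leaves)
print(leaves, type(obj2).__name__, obj2.params)
print(jax.jit(lambda o: o.params["scale"] * o.params["trunc"])(obj))
```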
Then, I have a derived class that is instantiated with arguments that define the Base.params dictionary.
For now, let us ignore the "ws" parameter.
With this snippet it is possible to run
or
In a real context, however, computing variables such as "_y" can be costly, so one may ask why not use a storage mechanism: compute _y on the first call and reuse it afterwards, since it depends only on parameters defined when the object is constructed. Note that this initialization cannot be done in the __init__ method, because that causes problems with JAX as soon as jit is involved.
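A short illustration of why heavy work in `__init__` is fragile. It assumes, as the JAX documentation on custom pytrees warns as far as I recall, that transformations may call `tree_unflatten` with values that are not numbers at all (for example `object()` sentinels or `None`). `Eager`, `Lazy` and `rebuild_with_placeholders` are hypothetical names; the helper only imitates that behaviour by hand.

```python
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


class _TwoParams:
    """Shared (hypothetical) pytree plumbing: beta and scale are the leaves."""

    def tree_flatten(self):
        return (self.beta, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class Eager(_TwoParams):
    def __init__(self, beta, scale):
        self.beta, self.scale = beta, scale
        self._y = scale**beta  # computed eagerly: breaks on placeholder leaves


@register_pytree_node_class
class Lazy(_TwoParams):
    def __init__(self, beta, scale):
        self.beta, self.scale = beta, scale  # only store, never compute

    @property
    def _y(self):
        return self.scale**self.beta


def rebuild_with_placeholders(obj):
    # Mimic a transformation rebuilding the object from non-numeric leaves.
    leaves, treedef = tree_flatten(obj)
    return tree_unflatten(treedef, [object()] * len(leaves))


print(type(rebuild_with_placeholders(Lazy(2.0, 3.0))).__name__)
try:
    rebuild_with_placeholders(Eager(2.0, 3.0))
except TypeError as err:
    print("Eager failed:", err)
```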
This is the reason for the "ws" (aka work space) argument; I have put some "####" markers to show where it is used.
If one reruns the above Python code, prints obj2.params for instance, and then calls obj2._y repeatedly, one sees the message "use record" after the first call, which signals that the "complicated code" is bypassed and the result stored in the ws dictionary is returned.
So far so good, but this is not the end of the story, and it is what motivated this discussion. During development we set up some tests, notably a "jitting" test and a "vmapping" test, as follows (adapted to the present use case):
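The original test code is not reproduced here. As a hypothetical sketch of what such consistency checks usually look like, assuming a minimal pytree class (not the notebook's) and illustrative helper names `check_jitting` and `check_vmapping`: the jitted result must match the eager one, and a vmapped evaluation must match a Python loop over single objects.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class A:
    """Hypothetical minimal pytree used only to illustrate the two tests."""

    def __init__(self, beta, scale, trunc=0.0):
        self.beta, self.scale, self.trunc = beta, scale, trunc

    @property
    def _y(self):
        return jnp.where(self.trunc > 0,
                         self.scale * self.trunc**self.beta,
                         self.scale**self.beta)

    def tree_flatten(self):
        return (self.beta, self.scale, self.trunc), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def check_jitting(obj):
    # The jitted result must match the eager one.
    eager = obj._y
    jitted = jax.jit(lambda o: o._y)(obj)
    np.testing.assert_allclose(jitted, eager, rtol=1e-6)


def check_vmapping(betas, scales, truncs):
    # One batched object versus a Python loop over single objects.
    batched = jax.vmap(lambda o: o._y)(A(betas, scales, truncs))
    looped = jnp.stack([A(b, s, t)._y for b, s, t in zip(betas, scales, truncs)])
    np.testing.assert_allclose(batched, looped, rtol=1e-6)


check_jitting(A(beta=2.0, scale=3.0, trunc=0.5))
check_vmapping(jnp.array([1.0, 2.0]), jnp.array([3.0, 3.0]), jnp.array([0.0, 0.5]))
```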
When the "ws" mechanism is not used at all and "_y" is recomputed every time, both tests pass (i.e. return True).
But when the "ws" mechanism is used, the "vmap" test crashes because of the duplicate function. There may be a workaround that adjusts this duplicate function, but my question is more about the "Work Space" mechanism itself. Do you know of any other mechanism that fits our development philosophy? Thanks.
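One alternative pattern (a design choice, not an official JAX mechanism) is to make the precomputed value an explicit, optional child of the pytree, so that `jit` and `vmap` carry and batch it like any other leaf instead of it living in a hidden dict. The sketch below assumes a hypothetical simplified class; `precompute` and `_y_cached` are illustrative names. Since `None` is an empty pytree, an uncached object has three leaves and a cached one has four.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class A:
    """Hypothetical sketch: the precomputed _y is an optional pytree child."""

    def __init__(self, beta, scale, trunc=0.0, y=None):
        self.beta, self.scale, self.trunc = beta, scale, trunc
        self._y_cached = y  # None means "not computed yet"

    def _compute_y(self):
        return jnp.where(self.trunc > 0,
                         self.scale * self.trunc**self.beta,
                         self.scale**self.beta)

    @property
    def _y(self):
        if self._y_cached is not None:
            return self._y_cached
        return self._compute_y()

    def precompute(self):
        # Return a new object carrying the cached value (no in-place mutation).
        return A(self.beta, self.scale, self.trunc, y=self._compute_y())

    def tree_flatten(self):
        # The cache is a regular child, so it flows through jit/vmap.
        return (self.beta, self.scale, self.trunc, self._y_cached), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


batch = A(
    jnp.array([1.0, 2.0]), jnp.array([3.0, 3.0]), jnp.array([0.0, 0.5])
).precompute()
y_vmapped = jax.vmap(lambda o: o._y)(batch)  # reads the batched cached leaf
y_doubled = jax.jit(lambda o: 2.0 * o._y)(batch)
print(y_vmapped, y_doubled)
```

The trade-off is that the cached leaf must stay consistent with the parameters (here it is only created through `precompute`), and gradients with respect to `beta` or `scale` would not flow through an already cached `_y`.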