Best practices for updating array parameters and buffer management #16904
Hello, I am currently working with a model and trying to understand the most efficient ways to manage and update arrays, specifically model parameters, in JAX without unnecessarily duplicating the parameters on the device. I would also like a deeper understanding of the technical details behind these methods, as I suspect my current understanding is incomplete. For context, consider the following class:

```python
class Inference:
    forward_fn: ...  # assuming jitted
    params: ...

    def serve(self, inputs):
        return self.forward_fn(inputs, self.params)

    def set_params(self, new_params):
        self.params = new_params
```

My primary concern revolves around the … However, this raises several questions and concerns: …
I am curious whether there is a way to manually update the buffer at a time of our choosing, taking into account that … Any advice, best practices, or insights into how JAX handles these situations would be greatly appreciated. Thank you.
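For anyone who wants to experiment with the setup above, here is a minimal runnable sketch of the `Inference` class. The linear `forward_fn`, the parameter shapes, and the `make_params` helper are hypothetical stand-ins, not part of the original post; only the `serve` / `set_params` structure comes from the question.

```python
import jax
import jax.numpy as jnp


@jax.jit
def forward_fn(inputs, params):
    # Hypothetical model: a single linear layer.
    return inputs @ params["w"] + params["b"]


def make_params(scale):
    # Hypothetical helper that builds a parameter pytree on the default device.
    return {"w": jnp.full((3, 2), scale), "b": jnp.zeros((2,))}


class Inference:
    def __init__(self, forward_fn, params):
        self.forward_fn = forward_fn  # assumed to be jitted
        self.params = params

    def serve(self, inputs):
        return self.forward_fn(inputs, self.params)

    def set_params(self, new_params):
        # Rebinding the attribute drops this object's reference to the old
        # pytree; its device buffers are freed once nothing else refers to them.
        self.params = new_params


model = Inference(forward_fn, make_params(1.0))
x = jnp.ones((4, 3))
y1 = model.serve(x)
model.set_params(make_params(2.0))
y2 = model.serve(x)
```

Because the new parameters have the same shapes and dtypes as the old ones, the second `serve` call reuses the already-compiled `forward_fn` instead of retracing.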
The operation …
Outside of `jit`, there is no way to update the buffer pointed to by … Hope that helps!
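The reply limits its claim to code outside of `jit`. Inside `jit`, one mechanism JAX offers for reusing device memory is buffer donation: passing `donate_argnums` to `jax.jit` tells XLA it may reuse a donated input's buffers for the outputs. This is an addition, not something stated in the reply. The `sgd_step` function and its learning rate are hypothetical, and donation is only a hint: some backends may ignore it and emit a warning instead.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Donating argument 0 lets XLA reuse the buffers of `params` for the
# returned pytree instead of keeping two copies on the device.
@partial(jax.jit, donate_argnums=0)
def sgd_step(params, grads):
    # Hypothetical update rule: plain SGD with a fixed learning rate.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
grads = {"w": jnp.ones((3, 2)), "b": jnp.ones((2,))}

# Rebind the name to the result: after donation the old `params` arrays
# must not be used again, because their buffers may have been reused.
params = sgd_step(params, grads)
```

Donation pays off when the output has the same shapes and dtypes as the donated input, which is the usual situation for a parameter update.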
In short, assuming your code is not wrapped in `jit`, the way to think of it is like this: `self.params` references some Python object that we might call `obj`. Python maintains a refcount on `obj` that has a value of at least 1, because `self.params` currently references it. When you run `self.params = new_params`, the refcount of `obj` is decreased by 1. If `self.params` is the only reference to `obj`, then the refcount will become zero.
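The reference counting described above can be observed with `sys.getrefcount`. This is a minimal sketch: `Holder` is a hypothetical stand-in for the `Inference` object, and because absolute counts vary between Python versions (and include the call's own argument), only the difference between two counts is meaningful. As an addition not taken from the reply, the sketch also uses `jax.Array.delete()`, which releases an array's device buffer immediately, even while Python references to it remain.

```python
import sys

import jax.numpy as jnp


class Holder:
    # Hypothetical stand-in for the `Inference` object holding parameters.
    def __init__(self, params):
        self.params = params


obj = jnp.arange(4.0)
holder = Holder(obj)

# sys.getrefcount counts its own argument too, so compare counts
# rather than reading the absolute value.
count_with_holder = sys.getrefcount(obj)
holder.params = jnp.zeros(4)  # the equivalent of `self.params = new_params`
count_after_rebind = sys.getrefcount(obj)

# `obj` still has one reference (the name `obj`), so its buffer is alive.
# Deleting the array explicitly releases the device buffer right away.
obj.delete()
```

After `obj.delete()`, using `obj` again raises an error, so explicit deletion is only safe once nothing else needs the old parameters.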