Best practices for design patterns #20497
Replies: 2 comments 14 replies
-
I think best practices for NN APIs are still an active area of work, which is why you find different projects making different choices. Maybe in a few years the dust will have settled and there will be an answer to your question.
-
As Jake stated, there is no consensus yet. Flax is the most widely used, by both Google and Hugging Face (apart from Keras), so it would be the safer option. Since you are looking for "something in JAX that closely matches the flow in TF/PyTorch", take a look at Flax's experimental NNX API. It represents Modules as "PyGraphs" and is able to express changes to the graph structure through jitted functions, so model code tends to look very similar to PyTorch/Keras:

import jax
import jax.numpy as jnp
from flax.experimental import nnx

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs()  # get a unique random key
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))  # initialize parameters
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w.value + self.b.value

rngs = nnx.Rngs(0)  # explicit RNG handling
model = Linear(din=2, dout=3, rngs=rngs)  # initialize the model
x = jnp.empty((1, 2))  # placeholder input (jnp.empty does not produce random data)
y = model(x)  # forward pass

Stateful updates to a Module in a jitted function are written as regular stateful Python code:

@nnx.jit
def add_one_to_bias(model: Linear):
  model.b.value += 1
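For comparison, here is a rough sketch of the same bias update in plain JAX, without NNX. It assumes the parameters are kept in an ordinary dict (a pytree) rather than on a Module; the `params` layout is made up for illustration. The point is only that a function under `jax.jit` has to return the updated values instead of mutating them in place, which is the bookkeeping NNX hides:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter layout: a plain dict pytree instead of an nnx.Module.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

@jax.jit
def add_one_to_bias(params):
    # A jitted function should be pure, so build and return a new pytree
    # rather than modifying the input.
    return {**params, "b": params["b"] + 1}

new_params = add_one_to_bias(params)  # the caller rebinds the result
```

The original `params` is left untouched; the caller decides whether to keep the old or the new version.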
-
I have dabbled with a lot of libraries and frameworks for NNs in the past. Given the pace of progress in ML, no single library/framework suits all needs. Some are mature, some have got the syntax and the design right, some have the momentum from the ecosystem. One of the most frustrating experiences in all this is the design inconsistency between the frameworks/libraries and the fragmentation within the ecosystem of a given framework/library.
We have all been trying to develop something in JAX that closely matches the flow in TF/PyTorch but also allows us to use the goodies that come with the functional programming paradigm. Looking closely, we can split the design patterns of all frameworks/libraries that are built on top of JAX into three categories:
init-apply pattern: Flax, Haiku, etc.
There are other ways as well to build something on top of JAX, and it is extremely hard for anyone to figure out the best design pattern without trying them for a lot of use cases or building something simple from scratch. What is the recommended design pattern if we have to develop something minimal (limited to NNs only), especially when we have to deal with both stateful and stateless operations?
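To make the question concrete, this is roughly what a minimal init-apply layer looks like in plain JAX when it has to carry state. It is only a sketch under my own assumptions: the names `init` and `apply` and the `calls` counter are hypothetical and not taken from Flax or Haiku, and the counter stands in for real state such as BatchNorm running statistics.

```python
import jax
import jax.numpy as jnp

def init(key, din, dout):
    """Create trainable params and non-trainable state as separate pytrees."""
    params = {
        "w": jax.random.normal(key, (din, dout)) * 0.1,
        "b": jnp.zeros((dout,)),
    }
    # Hypothetical state: a call counter, standing in for e.g. running statistics.
    state = {"calls": jnp.zeros((), dtype=jnp.int32)}
    return params, state

@jax.jit
def apply(params, state, x):
    """Stateless math plus an explicit state update; nothing is mutated."""
    y = x @ params["w"] + params["b"]
    new_state = {"calls": state["calls"] + 1}
    return y, new_state

params, state = init(jax.random.PRNGKey(0), din=2, dout=3)
x = jnp.ones((1, 2))
y, state = apply(params, state, x)  # thread the state through every call
y, state = apply(params, state, x)
```

Stateless operations only need `params`; stateful ones take `state` in and hand a new `state` back, so the caller owns all of it.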