Hi, I built a BiLSTM by applying an `LSTMCell` (implemented in JAX) inside a Python `for` loop. When I apply the LSTM to a sequence, I have to process each element of the sequence in that loop, like this:
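```python
import jax.numpy as jnp

seq_outputs = []
for t in range(seq_len):
    # forward direction reads time step t
    out_forward = modulelist[0]([input_seq[:, t, :], (h[:num_layers, :, :], c[:num_layers, :, :])])
    # backward direction reads time step seq_len - t - 1
    out_back = modulelist[1]([input_seq[:, seq_len - t - 1, :], (h[num_layers:, :, :], c[num_layers:, :, :])])
    # note: out_back belongs to position seq_len - t - 1, not t
    out = jnp.concatenate((out_forward, out_back), axis=-1)
    seq_outputs.append(out)
```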
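To illustrate why compilation time grows with `seq_len`, here is a minimal sketch in which a hypothetical toy `step` function stands in for the real `LSTMCell`. When traced, a Python `for` loop is unrolled into one copy of the step's operations per time step, whereas `jax.lax.scan` traces its body only once. `jax.make_jaxpr` is used here just to count the traced operations.

```python
import jax
import jax.numpy as jnp


def step(h, x):
    # Hypothetical toy recurrence standing in for one LSTMCell call.
    return jnp.tanh(h + x)


def loop_version(xs, h):
    # Python loop: every iteration is traced separately (unrolled).
    outs = []
    for t in range(xs.shape[0]):
        h = step(h, xs[t])
        outs.append(h)
    return jnp.stack(outs)


def scan_version(xs, h):
    # lax.scan traces `body` once, independent of the sequence length.
    def body(h, x):
        h = step(h, x)
        return h, h

    _, outs = jax.lax.scan(body, h, xs)
    return outs


def n_eqns(fn, seq_len):
    # Number of top-level operations in the traced program.
    xs = jnp.ones((seq_len, 4))
    h = jnp.zeros((4,))
    return len(jax.make_jaxpr(fn)(xs, h).jaxpr.eqns)


for seq_len in (10, 100):
    print(seq_len, n_eqns(loop_version, seq_len), n_eqns(scan_version, seq_len))
```

The loop version's operation count scales with `seq_len`, and XLA has to compile all of it; the scan version stays the same size.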
`modulelist[i]` is a `Sequential` of multiple `LSTMCell`s. JIT compilation takes a long time because `seq_len` is large: the Python `for` loop is traced step by step, so the compiled program grows with the sequence length. If I rewrite the code using `jax.lax.scan()`, I get this error:

```
Value <__main__.Sequential object at 0x7f1c3c342340> with type <class '__main__.Sequential'> is not a valid JAX type
```
I know this happens because `Module` in my own framework is not a PyTree type, so `Sequential` and `LSTMCell` (both inherit from `Module`) are not PyTrees either. How can I speed up this code without modifying my `Module` class? I'd appreciate any help!
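One common way around this, sketched here under stated assumptions, is to keep the module objects out of `scan`'s `carry` and `xs` entirely and capture them in the scan body by closure, so only arrays pass through `scan` and `jit`. The `LSTMCell` class below is a hypothetical plain-Python stand-in (not a PyTree, single layer, weights stored as attributes), not the framework's actual class, and the sketch assumes the cell is stateless: it takes the `(h, c)` state as input and returns the new state.

```python
import jax
import jax.numpy as jnp


class LSTMCell:
    """Hypothetical stand-in for a framework Module: a plain Python object,
    NOT registered as a PyTree, with weights stored as attributes."""

    def __init__(self, key, in_dim, hid_dim):
        k1, k2 = jax.random.split(key)
        self.w_x = jax.random.normal(k1, (in_dim, 4 * hid_dim)) * 0.1
        self.w_h = jax.random.normal(k2, (hid_dim, 4 * hid_dim)) * 0.1
        self.b = jnp.zeros((4 * hid_dim,))

    def __call__(self, x, state):
        h, c = state
        gates = x @ self.w_x + h @ self.w_h + self.b
        i, f, g, o = jnp.split(gates, 4, axis=-1)
        c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
        h = jax.nn.sigmoid(o) * jnp.tanh(c)
        return h, (h, c)


def run_direction(cell, xs, state, reverse):
    # `cell` is captured by the closure, so it never goes through
    # scan's carry or xs; only the (h, c) arrays and xs do.
    def body(carry, x_t):
        out, new_carry = cell(x_t, carry)
        return new_carry, out

    _, outs = jax.lax.scan(body, state, xs, reverse=reverse)
    return outs  # with reverse=True, outs[t] is still aligned with xs[t]


def bilstm(fwd_cell, bwd_cell, xs, state):
    # xs is time-major: (seq_len, batch, in_dim).
    out_f = run_direction(fwd_cell, xs, state, reverse=False)
    out_b = run_direction(bwd_cell, xs, state, reverse=True)
    return jnp.concatenate([out_f, out_b], axis=-1)


key_f, key_b = jax.random.split(jax.random.PRNGKey(0))
fwd = LSTMCell(key_f, in_dim=3, hid_dim=5)
bwd = LSTMCell(key_b, in_dim=3, hid_dim=5)

# Close over the non-PyTree cells so jit only sees array arguments.
bilstm_jit = jax.jit(lambda xs, state: bilstm(fwd, bwd, xs, state))

seq_len, batch = 50, 2
xs = jax.random.normal(jax.random.PRNGKey(1), (seq_len, batch, 3))
state = (jnp.zeros((batch, 5)), jnp.zeros((batch, 5)))
ys = bilstm_jit(xs, state)  # shape (seq_len, batch, 2 * hid_dim)
```

For batch-first input like `input_seq[:, t, :]`, swap the first two axes (e.g. `jnp.swapaxes(input_seq, 0, 1)`) before scanning. A multi-layer `Sequential` could be handled the same way by having `body` call each layer's cell in turn with a tuple of per-layer states. Two caveats of this sketch: closed-over weights are baked into the compiled program as constants, so computing gradients with respect to them would require passing the parameters in as array arguments; and a module that mutates its own attributes inside `body` would not work correctly under `scan`.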