jnp.linalg.norm error #17325
I am trying to fine-tune a wav2vec2 model, with BART as the language model, using JAX/Flax, but I got this error.
JAX version:
From the traceback, it seems the error is on line 1051 of run_flax_speech_recognition_seq2seq.py:

    "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),

The output of
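A minimal sketch of what that line is doing, assuming the error is JAX rejecting a Python list as an array argument (recent JAX versions do this for jnp functions). `jax.tree_util.tree_leaves` returns a plain Python list, not an array. The nested dict `encoder_norms` below is a hypothetical stand-in for `layer_grad_norm["encoder"]`, assumed to hold one scalar norm per parameter. Stacking the leaves into a single 1-D array first gives `jnp.linalg.norm` a valid input.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for layer_grad_norm["encoder"]: one scalar
# gradient norm per parameter, arranged as a nested dict (a pytree).
encoder_norms = {
    "layer_0": {"kernel": jnp.array(3.0), "bias": jnp.array(4.0)},
    "layer_1": {"kernel": jnp.array(12.0)},
}

# tree_leaves flattens the pytree into a plain Python list of arrays
leaves = jax.tree_util.tree_leaves(encoder_norms)
print(type(leaves))  # <class 'list'>

# Recent JAX versions refuse Python lists as array arguments to jnp functions
try:
    jnp.linalg.norm(leaves)
except TypeError as err:
    print("TypeError:", err)

# Stacking the scalar leaves into one 1-D array gives a valid input;
# its L2 norm is the norm over all encoder parameters.
encoder_grad_norm = jnp.linalg.norm(jnp.stack(leaves))
print(encoder_grad_norm)  # 13.0
```

`jnp.stack` only works here because every leaf has the same shape (scalars). If the leaves were raw gradient arrays of different shapes, they would need to be reduced to scalars first.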
Hi @jakevdp, thanks for the reply.
I tried this code and it worked for me:

    "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])[0]),
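Indexing with `[0]` avoids the error, but it changes what gets logged: only the first leaf's norm is used, not the norm over the whole encoder. The minimal sketch below compares the two. It assumes, as an illustration only, that `layer_grad_norm["encoder"]` holds one scalar per parameter; `encoder_norms` is a hypothetical stand-in. The full norm is computed as the square root of the sum of squared leaf values. This approach also works if the leaves are arrays of differing shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-parameter gradient norms for the encoder (scalars)
encoder_norms = {
    "layer_0": {"bias": jnp.array(4.0), "kernel": jnp.array(3.0)},
    "layer_1": {"kernel": jnp.array(12.0)},
}

# Leaves come out in sorted-key order: layer_0/bias, layer_0/kernel, layer_1/kernel
leaves = jax.tree_util.tree_leaves(encoder_norms)

# The workaround from the reply: norm of the first leaf only
first_leaf_norm = jnp.linalg.norm(leaves[0])

# Norm over all leaves: sqrt of the sum of squares of every element in every leaf
full_norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))

print(first_leaf_norm, full_norm)  # 4.0 13.0
```

If the logged metric is meant to be the overall encoder gradient norm, the second form is the one that matches.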