Error through converting a jax numpy pre-trained weight to h5 weight #14151
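Hi, thanks for the question! This actually has nothing to do with JAX: it looks like the problem you're facing comes from misuse of the object returned by `jnp.load`. For an `.npz` file, that object is a collection of named arrays rather than a single array, so you need to pick an array out of it before passing it to `create_dataset`. So, for example, if you want to create your h5 dataset from the first array in the file, you could do something like this:

    import h5py
    import jax.numpy as jnp

    # Load the .npz archive; the result holds several named arrays.
    jax_file = jnp.load('ViT-B_16_imagenet21k.npz')
    # Select the first array by its key.
    arr = jax_file[jax_file.files[0]]
    # Write that single array to an HDF5 dataset named 'weights'.
    with h5py.File('ViT-B_16_imagenet21k.h5', 'w') as hf:
        hf.create_dataset('weights', data=arr)

You can read more about |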
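If you want every array in the archive rather than just the first one, the same idea extends to a loop over `.files`. The following is only a minimal sketch: the helper name `npz_to_h5` is hypothetical, and it uses `np.load` directly instead of `jnp.load`, on the assumption that plain NumPy is enough for reading the file.

```python
import h5py
import numpy as np


def npz_to_h5(npz_path, h5_path):
    """Copy every array in an .npz archive into an HDF5 file.

    Hypothetical helper for illustration: one dataset per key in the archive.
    Returns the list of keys that were written.
    """
    with np.load(npz_path) as npz, h5py.File(h5_path, "w") as hf:
        names = list(npz.files)
        for name in names:
            # If a key contains '/', h5py treats it as a group path and
            # creates the intermediate groups automatically.
            hf.create_dataset(name, data=npz[name])
    return names
```

Something like `npz_to_h5('ViT-B_16_imagenet21k.npz', 'ViT-B_16_imagenet21k.h5')` would then write one dataset per stored array. Note that this produces a plain HDF5 file; it is not necessarily laid out the way a Keras weights loader expects.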
Hi all.
I have downloaded a JAX NumPy weight file with the .npz suffix, but when I tried to convert it to an h5 file I received this error:
The npz file will be downloaded to the local path "~/.keras/weights/".
My error is:
My question is: how can I correctly convert a JAX file into an h5 file?
Note 1: the output of dir(jax_file):
Note 2: the type of jax_file is:
Note 3: my TensorFlow version is 2.9.1.
Any help will be appreciated.