Calculating the number of parameters of models developed in JAX in the right way #6153
Hi all. I am trying to calculate the number of parameters of different ViT (Vision Transformer) models from the official repository: https://github.com/google-research/vision_transformer. Here's how a model is loaded as per the official code:

```python
VisionTransformer = models.KNOWN_MODELS[model].partial(num_classes=num_classes)
_, params = VisionTransformer.init_by_shape(
    jax.random.PRNGKey(0),
    [(batch['image'].shape[1:], batch['image'].dtype.name)])
params = checkpoint.load_pretrained(
    pretrained_path=f'{model}.npz',
    init_params=params,
    model_config=models.CONFIGS[model],
    logger=logger,
)
```

Here is the Colab Notebook where it is done end-to-end. An end-to-end example showing how to calculate the number of parameters here would be really helpful.
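The `params` returned above is a nested dict of arrays, i.e. a JAX pytree, so any parameter count is a sum over its array leaves. A quick way to see what those leaves are is to replace each one with its shape. The sketch below is illustrative only: `toy_params` is a hypothetical stand-in, not the real ViT parameter tree, whose keys and shapes will differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the nested `params` dict produced by
# init_by_shape / checkpoint.load_pretrained; real ViT keys differ.
toy_params = {
    "embedding": {"kernel": jnp.zeros((16, 16, 3, 768)), "bias": jnp.zeros((768,))},
    "head": {"kernel": jnp.zeros((768, 1000)), "bias": jnp.zeros((1000,))},
}

# Replace every array leaf by its shape while keeping the nesting intact.
shapes = jax.tree_util.tree_map(lambda x: x.shape, toy_params)
print(shapes)
```

Running the same `tree_map` on the real `params` prints the full layer-by-layer structure of the model.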
Answered by tomhennigan, Mar 20, 2021
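The following should work:

```python
import jax

# Sum the element counts of every array leaf in the parameter pytree.
# jax.tree_util.tree_leaves is used because the top-level jax.tree_leaves
# alias is deprecated in newer JAX releases.
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
```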
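To check the one-liner end-to-end without downloading a checkpoint, here is a minimal sketch on a hypothetical toy parameter tree (`toy_params`, `count_params` and `count_bytes` are illustrative names, not part of the ViT repository). It also shows two small extensions of the same idea: a per-top-level-module breakdown and the memory footprint, which multiplies each leaf's size by its dtype's item size.

```python
import jax
import jax.numpy as jnp

def count_params(params):
    """Total number of scalar parameters over all array leaves of a pytree."""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

def count_bytes(params):
    """Memory footprint of the parameters, based on each leaf's dtype."""
    return sum(x.size * x.dtype.itemsize for x in jax.tree_util.tree_leaves(params))

# Hypothetical toy tree standing in for the ViT `params`; shapes are illustrative.
toy_params = {
    "embedding": {"kernel": jnp.zeros((16, 16, 3, 768)), "bias": jnp.zeros((768,))},
    "head": {"kernel": jnp.zeros((768, 1000)), "bias": jnp.zeros((1000,))},
}

total = count_params(toy_params)
# Count each top-level sub-tree separately to see where the parameters live.
per_module = {name: count_params(sub) for name, sub in toy_params.items()}

print(f"total: {total:,}")  # → total: 1,359,592
print(per_module)
print(count_bytes(toy_params))  # jnp.zeros defaults to float32, i.e. 4 bytes each
```

Passing the loaded ViT `params` to `count_params` gives the model's parameter count; note that it counts elements, so the byte figure depends on the dtype the checkpoint was loaded in.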
Answer selected by sayakpaul