
Calculating the number of parameters of models developed in JAX the right way #6153

Answered by tomhennigan
sayakpaul asked this question in Q&A

The following should work:

param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
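Here is a minimal, self-contained sketch of the one-liner above. It assumes `params` is a pytree whose leaves are all arrays, such as the nested dicts that JAX neural-network libraries commonly use. The `count_params` helper and the two-layer MLP shapes are hypothetical and exist only for illustration. `jax.tree_util.tree_leaves` flattens the nested structure into a list of leaf arrays, and `.size` gives the number of scalar entries in each one.

```python
import jax
import jax.numpy as jnp

def count_params(params) -> int:
    """Total number of scalar entries across all array leaves of a pytree (hypothetical helper)."""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

# Hypothetical two-layer MLP (784 -> 128 -> 10) stored as a nested dict pytree.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "dense1": {"w": jax.random.normal(key1, (784, 128)), "b": jnp.zeros(128)},
    "dense2": {"w": jax.random.normal(key2, (128, 10)), "b": jnp.zeros(10)},
}

total = count_params(params)
print(total)  # 784*128 + 128 + 128*10 + 10 = 101770

# The same helper applied to each top-level subtree gives a per-layer breakdown.
per_layer = {name: count_params(layer) for name, layer in params.items()}
print(per_layer)  # {'dense1': 100480, 'dense2': 1290}
```

This counts every array leaf in the tree. If non-trainable state (for example, batch-norm statistics) lives in the same pytree, it is included as well, so pass only the trainable subtree if that is what you want to count.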

@sayakpaul
@rsantet
Answer selected by sayakpaul