Best practice of type for a scalar value #18328

Answered by jakevdp
ToshiyukiBandai asked this question in Q&A
All of these should be fine. There are several differences:

  • Outside a JAX-transformed function, Python floats lead to computations being executed on the host rather than on the device. For example, x = 1.0; y = 2.0; z = x * y is a host-side multiplication, not a multiplication on device.
  • In JAX's current implementation, options 1 and 2 are both treated as weakly-typed floats, because their dtype is not specified. So if you do x * jnp.bfloat16(1.0), the result will be bfloat16, rather than the bfloat16 operand being promoted to float32 or float64.
  • The third option is strongly typed under JAX's current implementation, because for list inputs JAX first converts them to a NumPy array for efficiency. So if you do x * jnp…
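
The weak versus strong typing distinction above can be checked directly. The snippet below is only a sketch: the original question is not reproduced here, so the three definitions (a Python float, a 0-d array built from a scalar, and an array built from a list) are assumed stand-ins for the three options, and the variable names are made up for illustration. The dtypes noted in the comments reflect JAX's default configuration (64-bit mode disabled) and could change in future versions.

```python
import jax.numpy as jnp

# Assumed stand-ins for the three options (hypothetical names).
x_py = 1.0                    # plain Python float
x_scalar = jnp.array(1.0)     # JAX array built from a Python scalar, no dtype given
x_list = jnp.array([1.0])     # JAX array built from a list, no dtype given

# weak_type reports whether JAX treats the array's dtype as "weak".
print(x_scalar.weak_type)     # weakly typed
print(x_list.weak_type)       # strongly typed

half = jnp.bfloat16(1.0)

# Weakly typed operands adopt the dtype of the strongly typed operand.
z_py = x_py * half
z_scalar = x_scalar * half
# A strongly typed float32 operand promotes the bfloat16 instead.
z_list = x_list * half

print(z_py.dtype, z_scalar.dtype, z_list.dtype)
```

If bfloat16 (or any other specific dtype) should be preserved throughout a computation, passing dtype= explicitly when creating the array avoids relying on these implementation details.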

Answer selected by ToshiyukiBandai