Performance regression while scaling up GPUs when computing determinants. #28675

Answered by dfm
Kieran-Loehr asked this question in Q&A

Good question! It looks like slogdet is backed by either a QR or LU decomposition. These decompositions are themselves backed by calls to cuSOLVER on NVIDIA GPUs. Until recently (JAX v0.5.3, I think), these library calls didn't support sharding like this out of the box. If you try your experiment with the latest version of JAX, I predict that you will see the scaling you expect. For older versions of JAX like the one you're using, the recommendation would be to use shard_map.
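For the older-version route, here is a rough sketch of what a shard_map version could look like. It assumes the workload is a batch of independent square matrices split along the leading axis, which may not match your setup exactly. The names `sharded_slogdet`, `_local_slogdet`, `make_batch`, and the mesh axis name "batch" are made up for illustration, and the import fallback is there because shard_map lives under jax.experimental in older releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level export in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# One-dimensional mesh over every visible device; "batch" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))


def _local_slogdet(mats):
    # Inside shard_map, `mats` is only this device's slice of the batch,
    # so the decomposition behind slogdet runs on local data.
    sign, logabsdet = jnp.linalg.slogdet(mats)
    return sign, logabsdet


# Hypothetical wrapper: split the leading axis across devices on the way in,
# and reassemble both outputs along the same axis on the way out.
sharded_slogdet = jax.jit(
    shard_map(
        _local_slogdet,
        mesh=mesh,
        in_specs=(P("batch"),),
        out_specs=(P("batch"), P("batch")),
    )
)


def make_batch(key, batch, n):
    # Illustrative input only: random matrices shifted by n * I so they are
    # comfortably non-singular.
    return jax.random.normal(key, (batch, n, n)) + n * jnp.eye(n)
```

The leading dimension has to be divisible by the number of devices in the mesh, so something like `sharded_slogdet(make_batch(jax.random.PRNGKey(0), 4 * len(jax.devices()), 8))` should work on any device count. On a single device this reduces to an ordinary slogdet call, which makes it easy to check against the unsharded result.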

Answer selected by Kieran-Loehr

dfm May 13, 2025
Collaborator
