I am getting the error "XlaRuntimeError: INTERNAL: Unable to serialize MPS module" when running functions that contain jax.lax.cond. Here is a very simplified example:
When I call coords_at_t0 with any JAX array as input, I get the following error:
XlaRuntimeError Traceback (most recent call last)
Cell In[25], line 1
----> 1 coords_at_t0(jnp.ones(1))
Cell In[22], line 9
7 def coords_at_t0(xt):
8 t = xt[-1]
----> 9 return jax.lax.cond(t==0., true_fn, false_fn, operand=xt)
[... skipping hidden 17 frame]
File ~/test_venv/lib/python3.12/site-packages/jax/_src/compiler.py:303, in backend_compile(backend, module, options, host_callbacks)
297 return backend.compile(
298 built_c, compile_options=options, host_callbacks=host_callbacks
299 )
300 # Some backends don't have `host_callbacks` option yet
301 # TODO(sharadmv): remove this fallback when all backends allow `compile`
302 # to take in `host_callbacks`
--> 303 return backend.compile(built_c, compile_options=options)
304 except xc.XlaRuntimeError as e:
305 for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:
XlaRuntimeError: INTERNAL: Unable to serialize MPS module
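For reference, here is a minimal sketch of the failing pattern, reconstructed from the traceback (the coords_at_t0 body and its `jax.lax.cond` call) and from the remark below that `true_fn` returns `jnp.zeros_like(xt)`. The body of `false_fn` is not shown anywhere, so the identity function used here is a hypothetical placeholder. On the default CPU backend this sketch should run without error; the error message above refers to an MPS module, which points to the `jax-metal` backend.

```python
import jax
import jax.numpy as jnp


def true_fn(xt):
    # Branch taken when the last entry (t) is zero: an all-zero array
    return jnp.zeros_like(xt)


def false_fn(xt):
    # Hypothetical placeholder: the original false_fn is not shown
    return xt


def coords_at_t0(xt):
    t = xt[-1]
    # Both branches must return arrays with the same shape and dtype
    return jax.lax.cond(t == 0., true_fn, false_fn, operand=xt)


print(coords_at_t0(jnp.ones(1)))            # last entry is 1.0, so false_fn runs
print(coords_at_t0(jnp.array([2.0, 0.0])))  # last entry is 0.0, so true_fn runs
```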
However, if I change the return line in true_fn to the following (or use another basic JAX function instead of cos):
return jnp.cos(xt) * 0.
This returns the same values as "jnp.zeros_like(xt)", yet there is no error: the code runs and returns correct results. Replacing "jnp.zeros_like(xt)" with "jnp.array([0.,...,0.])" (with a list of zeros of the same length as xt) does not help; the error is the same.
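The two versions of `true_fn` can be compared side by side. The sketch below uses a hypothetical identity `false_fn` (the original is not shown) and passes the branch function in as an extra argument purely for illustration. Multiplying by `0.` yields zeros for any finite input, so both versions return the same array here. One visible difference is that `jnp.zeros_like(xt)` uses only the shape and dtype of `xt`, whereas `jnp.cos(xt) * 0.` depends on its values; whether that is what trips the Metal plugin is not established.

```python
import jax
import jax.numpy as jnp


def true_fn_zeros(xt):
    # Original branch: constant zeros, independent of the values in xt
    return jnp.zeros_like(xt)


def true_fn_cos(xt):
    # Variant reported to avoid the error: still all zeros, but computed from xt
    return jnp.cos(xt) * 0.


def false_fn(xt):
    # Hypothetical placeholder: the original false_fn is not shown
    return xt


def coords_at_t0(xt, true_fn):
    t = xt[-1]
    return jax.lax.cond(t == 0., true_fn, false_fn, operand=xt)


x = jnp.array([0.5, 3.0, 0.0])  # last entry is 0, so the true branch runs
a = coords_at_t0(x, true_fn_zeros)
b = coords_at_t0(x, true_fn_cos)
print(jnp.array_equal(a, b))
```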
I am on a MacBook with an Intel CPU and an AMD Radeon Pro GPU, running macOS Sequoia 15.3.2, with:
jax==0.4.38
jax-metal==0.1.1
jaxlib==0.4.38
Any ideas about this, or suggestions for an alternative to jax.lax.cond, are welcome. Thanks in advance!
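One possible alternative to `jax.lax.cond` for a case this small is `jnp.where`, which evaluates both branches and selects between them, so the compiled program contains a select operation rather than a conditional. This is a sketch with a hypothetical identity `false_fn`; whether `jax-metal` 0.1.1 compiles it without the MPS serialization error has not been verified. Because both branches are always computed, this approach only suits branches that are cheap and safe to evaluate for any input.

```python
import jax
import jax.numpy as jnp


def true_fn(xt):
    return jnp.zeros_like(xt)


def false_fn(xt):
    # Hypothetical placeholder: the original false_fn is not shown
    return xt


@jax.jit
def coords_at_t0(xt):
    t = xt[-1]
    # Both branches are evaluated; the scalar predicate broadcasts over xt
    return jnp.where(t == 0., true_fn(xt), false_fn(xt))


print(coords_at_t0(jnp.ones(3)))
print(coords_at_t0(jnp.array([1.0, 2.0, 0.0])))
```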