Question about XLA Compilation Error #118
Replies: 5 comments
-
Hi @hwang-warren, what's the error? I don't see any tracebacks in your post.
-
Sorry, I forgot to attach the error message. I retried with the official Python script (train_jax_ppo.py). Here's the reproducing procedure:
cd mujoco_playground/learning
python train_jax_ppo.py --env_name=CartpoleBalance
Here's the success case. Even in the success case I get quite a lot of warnings, and although the code runs, the policy doesn't converge (the reward peaks at 998.832 and then drops back to 307.664 by the end of training).
Environment Config:
action_repeat: 1
ctrl_dt: 0.01
episode_length: 1000
sim_dt: 0.01
vision: false
vision_config:
  enabled_geom_groups:
  - 0
  - 1
  - 2
  gpu_id: 0
  history: 3
  render_batch_size: 512
  render_height: 64
  render_width: 64
  use_rasterizer: false
PPO Training Parameters:
action_repeat: 1
batch_size: 1024
discounting: 0.995
entropy_cost: 0.01
episode_length: 1000
learning_rate: 0.001
normalize_observations: true
num_envs: 2048
num_evals: 10
num_minibatches: 32
num_timesteps: 60000000
num_updates_per_batch: 16
reward_scaling: 10.0
unroll_length: 30
Experiment name: CartpoleBalance-20250423-224931
Logs are being stored in: /home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/logs/CartpoleBalance-20250423-224931
No checkpoint path provided, not restoring from checkpoint
Checkpoint path: /home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/logs/CartpoleBalance-20250423-224931/checkpoints
0: reward=317.351
W0423 22:49:59.314628 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
6881280: reward=543.822
W0423 22:50:05.698982 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
13762560: reward=636.353
W0423 22:50:12.066721 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
20643840: reward=991.090
W0423 22:50:18.439847 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
27525120: reward=995.524
W0423 22:50:24.813345 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
34406400: reward=998.832
W0423 22:50:31.210746 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
41287680: reward=951.536
W0423 22:50:37.588347 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
48168960: reward=519.757
W0423 22:50:43.959108 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
55050240: reward=432.634
W0423 22:50:50.339858 133555922883648 array_metadata_store.py:362] [process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.
61931520: reward=307.664
Done training.
Time to JIT compile: 12.046052464982495
Time to train: 66.97819621698
Starting inference...
FPS for rendering: 50.0
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 501/501 [00:00<00:00, 735.02it/s]
Rollout video saved as 'rollout.mp4'.
Here's the failure case, where the code crashes:
(rl_mujoco_playground) well_hwang@well2204:~/Documents/code/mujoco_playground_ws/mujoco_playground/learning$ python train_jax_ppo.py --env_name=CartpoleBalance
Environment Config:
action_repeat: 1
ctrl_dt: 0.01
episode_length: 1000
sim_dt: 0.01
vision: false
vision_config:
  enabled_geom_groups:
  - 0
  - 1
  - 2
  gpu_id: 0
  history: 3
  render_batch_size: 512
  render_height: 64
  render_width: 64
  use_rasterizer: false
PPO Training Parameters:
action_repeat: 1
batch_size: 1024
discounting: 0.995
entropy_cost: 0.01
episode_length: 1000
learning_rate: 0.001
normalize_observations: true
num_envs: 2048
num_evals: 10
num_minibatches: 32
num_timesteps: 60000000
num_updates_per_batch: 16
reward_scaling: 10.0
unroll_length: 30
Experiment name: CartpoleBalance-20250423-203518
Logs are being stored in: /home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/logs/CartpoleBalance-20250423-203518
No checkpoint path provided, not restoring from checkpoint
Checkpoint path: /home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/logs/CartpoleBalance-20250423-203518/checkpoints
0: reward=317.351
2025-04-23 20:35:34.618070: F external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1099] Non-OK-status: executable.status()
Status: INTERNAL: ptxas exited with non-zero error code 139, output: - Failure occured when compiling fusion gemm_fusion_dot.182 with config '{block_m:16,block_n:16,block_k:128,split_k:2,num_stages:4,num_warps:2,num_ctas:1}'
Fused HLO computation:
%gemm_fusion_dot.182_computation (parameter_0.23: f32[30,1024,256], parameter_1.23: f32[30,1024,256]) -> f32[256,256] {
%parameter_0.23 = f32[30,1024,256]{2,1,0} parameter(0)
%bitcast.20826 = f32[30720,256]{1,0} bitcast(%parameter_0.23)
%parameter_1.23 = f32[30,1024,256]{2,1,0} parameter(1)
%bitcast.20827 = f32[256,30720]{0,1} bitcast(%parameter_1.23)
ROOT %dot.327 = f32[256,256]{1,0} dot(%bitcast.20826, %bitcast.20827), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="pmap(training_epoch)/jit(main)/while/body/while/body/while/body/transpose(jvp(MLP))/hidden_1/transpose" source_file="/home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/train_jax_ppo.py" source_line=353}
}
Fatal Python error: Aborted
Thread 0x0000772bb4c5d440 (most recent call first):
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/compiler.py", line 324 in backend_compile
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/profiler.py", line 334 in wrapper
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/compiler.py", line 694 in _compile_and_write_cache
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/compiler.py", line 432 in compile_or_get_cached
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1101 in from_hlo
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 942 in compile
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/profiler.py", line 334 in wrapper
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 641 in parallel_callable
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/linear_util.py", line 477 in memoized_fun
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 353 in xla_pmap_impl_lazy
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/api.py", line 1634 in cache_miss
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 182 in reraise_with_filtered_traceback
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/brax/training/agents/ppo/train.py", line 575 in training_epoch_with_timing
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/brax/training/agents/ppo/train.py", line 693 in train
File "/home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/train_jax_ppo.py", line 353 in main
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/absl/app.py", line 261 in _run_main
File "/home/well_hwang/miniconda3/envs/rl_mujoco_playground/lib/python3.11/site-packages/absl/app.py", line 316 in run
File "/home/well_hwang/Documents/code/mujoco_playground_ws/mujoco_playground/learning/train_jax_ppo.py", line 423 in <module>
Extension modules: jaxlib.cpu_feature_guard, numpy._core._multiarray_umath, numpy.linalg._umath_linalg, msgpack._cmsgpack, yaml._yaml, _cffi_backend, scipy._lib._ccallback_c, charset_normalizer.md, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, PIL._imaging, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._cython_nnls, scipy._lib._uarray._uarray, scipy.linalg._decomp_interpolative, scipy.optimize._bglu_dense, 
scipy.optimize._lsap, scipy.optimize._direct, psutil._psutil_linux, psutil._psutil_posix, lxml._elementpath, lxml.etree, scipy.ndimage._nd_image, scipy.ndimage._rank_filter_1d, _ni_label, scipy.ndimage._ni_label, google._upb._message, simplejson._speedups, kiwisolver._cext, requests.packages.charset_normalizer.md, requests.packages.chardet.md (total: 86)
Aborted (core dumped)
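The fused computation in the failure log is a single f32 matrix multiply: both operands are f32[30,1024,256], bitcast to [30720,256] and contracted over the 30720 = 30 × 1024 axis, producing the f32[256,256] gradient for the hidden_1 layer of the MLP (per the op_name). The 30 and 1024 plausibly correspond to unroll_length=30 and batch_size=1024 in the PPO parameters above. If this ends up being filed with JAX, a standalone repro that runs the same-shaped contraction outside Brax could help. This is only a sketch: hidden_weight_grad is a hypothetical name, and XLA may not pick the same Triton GEMM config on another GPU or jaxlib version, so it may not reproduce the crash.

```python
import jax
import jax.numpy as jnp

# Shapes copied from the failing fusion in the log: (time, batch, hidden).
T, B, H = 30, 1024, 256


@jax.jit
def hidden_weight_grad(acts, grads):
    # Same contraction as the fused dot after both operands are bitcast
    # to [T*B, H]: result[i, j] = sum over (t, b) of acts[t, b, i] * grads[t, b, j].
    return jnp.einsum("tbi,tbj->ij", acts, grads)


def main():
    key_a, key_g = jax.random.split(jax.random.PRNGKey(0))
    acts = jax.random.normal(key_a, (T, B, H), dtype=jnp.float32)
    grads = jax.random.normal(key_g, (T, B, H), dtype=jnp.float32)
    # block_until_ready() forces compilation and execution to finish here,
    # so a compiler crash surfaces at this line rather than later.
    out = hidden_weight_grad(acts, grads).block_until_ready()
    print(jax.devices(), out.shape, out.dtype)
    return out


if __name__ == "__main__":
    main()
```

Running it several times on the affected machine (the crash was intermittent for the script) would show whether this dot alone is enough to trigger the ptxas failure.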
-
Hi @hwang-warren, what's the difference between the "failure" and "success" runs? The commands look the same to me. Is your dev environment different?
-
The environment and commands are exactly the same for the success and failure cases. I configured my Python env following the [...] If I use [...] And if I run the training code at the very top, copied directly from the Colab notebook, it fails 100% of the time, and the error message is almost the same. That's why I'm raising this issue.
-
Got it, OK. Given that this is a compilation error from ptxas inside XLA, I would try updating JAX and see whether the error still reproduces. I would also make sure no other processes are running on the same GPU (in case it's a memory issue), or try a different GPU. I just ran this in the public Colab on an A100 and it worked as intended. XLA compilation errors like this that we cannot reproduce are difficult for us to debug, and you may have better luck opening an issue at https://github.com/jax-ml/jax.
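To act on the first two suggestions, it helps to record exactly which JAX/jaxlib build is installed and whether the GPU already has memory in use before training starts. A minimal sketch, assuming a recent JAX in which Device.memory_stats() and jax.print_environment_info() are available; report_jax_environment is a hypothetical helper name.

```python
import jax
import jaxlib


def report_jax_environment():
    """Collect the details worth attaching to a JAX bug report."""
    info = {
        "jax": jax.__version__,
        "jaxlib": jaxlib.__version__,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }
    # memory_stats() is populated on GPU/TPU; on CPU it returns None,
    # and some older versions may not implement it at all.
    try:
        stats = jax.devices()[0].memory_stats()
    except Exception:
        stats = None
    if stats:
        info["bytes_in_use"] = stats.get("bytes_in_use")
        info["bytes_limit"] = stats.get("bytes_limit")
    return info


if __name__ == "__main__":
    for name, value in report_jax_environment().items():
        print(f"{name}: {value}")
    # Prints jax, jaxlib, numpy and Python versions, plus nvidia-smi output if available.
    jax.print_environment_info()
```

By default JAX preallocates 75% of GPU memory per process, so a second JAX process on the same GPU leaves the first with far less headroom; nvidia-smi will list such processes.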
-
I'm trying to reproduce locomotion.ipynb in a .py file. Here's the code I copied:
I got the following XLA compilation error when the code reached the training part:
Here's my nvidia-smi output:
It's worth noting that I previously tried running a simple MuJoCo GPU simulation using jit and XLA, and it worked (see issue #102).
Do you have any idea how to fix the error?
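The abort happens inside XLA's GEMM fusion autotuner, and ptxas exits with code 139 (128 + 11, i.e. SIGSEGV), meaning ptxas itself crashed while compiling one Triton GEMM candidate config. A workaround sometimes used for this kind of crash, not one proposed in this thread, is to keep XLA from compiling that candidate, either by disabling autotuning or by turning off Triton GEMM fusions so the dot goes to cuBLAS. This is a hedged sketch: it assumes the installed jaxlib still recognizes --xla_gpu_autotune_level and --xla_gpu_enable_triton_gemm (XLA may abort at startup on flags it does not know, and flag names can change between releases), and apply_xla_flags is a hypothetical helper. Updating JAX remains the first thing to try.

```python
import os

# Assumption: both flags exist in the installed jaxlib's XLA build.
WORKAROUND_FLAGS = (
    "--xla_gpu_autotune_level=0",  # disable autotuning of GEMM/conv configs
    "--xla_gpu_enable_triton_gemm=false",  # route dots to cuBLAS, not Triton
)


def apply_xla_flags(flags=WORKAROUND_FLAGS):
    """Append flags to XLA_FLAGS; call before JAX initializes the GPU backend."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Keep any flags the user already set and skip duplicates.
    merged = existing + [f for f in flags if f not in existing]
    os.environ["XLA_FLAGS"] = " ".join(merged)
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    print("XLA_FLAGS =", apply_xla_flags())
    import jax  # imported only after XLA_FLAGS is set

    print(jax.devices())
```

Calling apply_xla_flags() at the very top of the training script, before anything imports jax, is equivalent to exporting XLA_FLAGS in the shell. Trying one flag at a time shows which one matters; either may cost some training speed.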