Hello, I have a program that uses Python decorators in conjunction with JAX. I have a weird issue: only when running on GPU with a specific dataset (and with the Python decorators enabled), the process gets 'Killed'. When I run this crashing example on the CPU, or with jit disabled, or run other examples entirely, I see no issues. For whatever reason, this specific example running on GPU ramps my SSD and GPU usage up to 100% and holds it there until the process is killed. When I turned on the debugging log level to compare the examples, it seemed that XLA compilation is what causes the crash. Is there a better way to debug the program and find out what the actual issue is? Or do you think I should just update my JAX version / CUDA drivers and hope for the best?
Replies: 1 comment
Just to follow up on this: I seem to have found the issue. `jax.debug.breakpoint()` was materializing too many values for the debugger and somehow causing the process to be killed. To fix it, I used `jax.debug.breakpoint(num_frames=1)` to limit the number of frames loaded into the debugger, which seemed to stop the crashes and significantly reduced the load on my machine when executing these breakpoints.
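For anyone hitting the same thing, here is a minimal sketch of that workaround. It is not my actual program: `make_step`, `step`, and the `debug` flag are hypothetical names made up for illustration, and the arithmetic is a placeholder. The only part taken from the fix above is the `num_frames=1` argument, which limits how many stack frames the breakpoint captures.

```python
import jax
import jax.numpy as jnp


def make_step(debug: bool = False):
    """Build a jitted function, optionally with a frame-limited breakpoint.

    `debug` is a plain Python flag (hypothetical, for illustration). It is
    resolved when the function is traced, so with debug=False no breakpoint
    is compiled into the program at all.
    """

    @jax.jit
    def step(x):
        y = jnp.sin(x) * 2.0  # placeholder computation
        if debug:
            # Capture only the innermost frame instead of the whole stack,
            # so fewer values have to be materialized for the debugger.
            jax.debug.breakpoint(num_frames=1)
        return y.sum()

    return step


# Breakpoint disabled: runs non-interactively.
step = make_step(debug=False)
result = step(jnp.arange(4.0))
print(result)
```

Calling `make_step(debug=True)` instead drops into the interactive debugger when `step` runs. Keeping the breakpoint behind a flag like this also makes it easy to check whether the breakpoint itself is what is slowing down compilation.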