Hello, I am currently trying to use a jax-written function inside a multiprocessing.Pool parallelization scheme, and I am running into unexpected thread creation and high CPU usage. Below is a small working example of the workflow I am trying to implement, with the job function [1] standing in for my real function, which should only "launch" 1 thread. When I run this script, htop reports an increase of ~40 threads for the Pool.map case and ~20 for the serial case, and I can also clearly see high resource usage across all CPU cores of my machine. I have followed the recommendations of issue #743 for the XLA_FLAGS, OPENBLAS_NUM_THREADS, MKL_NUM_THREADS, and OMP_NUM_THREADS settings, to no avail. Thus, I can synthesize my questions into two:

[1] - The (real) job function will be applied to each row of an (N, M) matrix, where N will be around 100.
[2] - I haven't been able to confirm whether I can use vmap/shard_map to do so, as I will also have to "mask" zeros that exist in each row, and I haven't been able to work out whether that is possible in this configuration (see the rough sketch after the code and outputs below).

Code:

```python
from multiprocessing import get_context
import os

# Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See
# https://github.com/google/jax/issues/743.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import jax
import time


def timer_decorator(func):
    def inner_function(*args, **kwargs):
        start_wall = time.perf_counter()
        start_cpu = time.process_time()
        out = func(*args, **kwargs)
        end_wall = time.perf_counter()
        end_cpu = time.process_time()
        wall_time = end_wall - start_wall
        cpu_time = end_cpu - start_cpu
        cpu_count = os.cpu_count()
        print(
            f"\tcpu usage {cpu_time/wall_time:.1f}/{cpu_count} wall_time:{wall_time:.1f}s"
        )
        return out

    return inner_function


def job(random_seed: int):
    """Jax-compatible work."""
    A = jax.random.normal(jax.random.PRNGKey(random_seed), (500, 500))
    return A @ A @ A @ A


@timer_decorator
def forkserver(xx_vals):
    print("Starting multiprocess with forkserver")
    with get_context("forkserver").Pool(processes=2) as pool:
        out = pool.map(job, xx_vals)
    return out


@timer_decorator
def serial_run(xx_vals):
    print("Starting serial for loop")
    out = []
    for i in xx_vals:
        out.append(job(i))
    return out


if __name__ == "__main__":
    xx_vals = list(range(2000))
    forkserver(xx_vals)
    serial_run(xx_vals)
```

Outputs:

```
Starting multiprocess with forkserver
	cpu usage 0.5/8 wall_time:15.8s
Starting serial for loop
	cpu usage 4.9/8 wall_time:13.5s
```
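To make [2] concrete, here is a rough sketch of the kind of per-row masking I have in mind; the row function and the zero-padding convention below are placeholders for illustration, not my real computation:

```python
import jax
import jax.numpy as jnp


def row_job(row):
    # Placeholder per-row computation: treat zeros as padding and
    # compute the mean over the remaining entries.
    mask = row != 0.0
    # jnp.where keeps shapes static, so the function stays vmap/jit friendly.
    total = jnp.sum(jnp.where(mask, row, 0.0))
    count = jnp.maximum(jnp.sum(mask), 1)
    return total / count


# Apply row_job to every row of an (N, M) matrix in one traced computation.
batched_job = jax.jit(jax.vmap(row_job))

x = jnp.array([[1.0, 2.0, 0.0, 0.0],
               [3.0, 0.0, 0.0, 0.0]])
print(batched_job(x))  # per-row means over the non-zero entries
```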
Replies: 1 comment 3 replies
I don't think I'm able to reproduce this locally. Two questions:

One thing worth trying is also adding inter_op_parallelism_threads=1 to your XLA_FLAGS:

```diff
 os.environ["XLA_FLAGS"] = (
-    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
+    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1"
 )
```
Stepping back, what are you looking to do? If you are looking to guarantee that jax (and the full stack below it) will use exactly one thread, this is simply not possible. No matter what, jax (etc.) will use threads. We can try to restrict the number of threads in places where configuration allows it, but it won't restrict the entire system to exactly one.
The flags we've discussed so far are the configuration points that I know of, so we're already restricting things where we can. It may also be worthwhile to try setting the environment variable NPROC=1.
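Putting the configuration points mentioned in this thread together, a minimal sketch of the environment setup might look like the following; it has to run before jax (or anything else that loads XLA/BLAS) is imported, and whether NPROC is honored by the underlying libraries is not guaranteed:

```python
import os

# Restrict XLA's CPU backend to single-threaded execution where configurable.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false "
    "intra_op_parallelism_threads=1 "
    "inter_op_parallelism_threads=1"
)
# Restrict the thread pools of the underlying BLAS/OpenMP libraries.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["NPROC"] = "1"

import jax  # imported only after the environment is configured
```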