Using jax.pmap inside jax.lax.scan for multicore computation of the loop #18477

Closed · Answered by JaySandesara
JaySandesara asked this question in Q&A

Hello,

I think I was needlessly confusing myself. With pmap, I don't need to use lax.scan at all. The problem is resolved simply by pmapping the interpolation functions:

import os, sys, importlib, glob

import numpy as np
import pandas as pd
import math
pd.options.mode.chained_assignment = None
import matplotlib.pyplot as plt
import multiprocessing

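# Expose one XLA host (CPU) device per core so pmap can spread work across them.
# This must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)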

import jax
from jax import numpy as jnp
from functools import partial
from jax import pmap

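# Run on CPU only ('jax_platforms' supersedes the deprecated 'jax_platform_name').
jax.config.update('jax_platforms', 'cpu')

# jax.lib.xla_bridge.get_backend() is deprecated; jax.default_backend() returns e.g. "cpu".
platform = jax.default_backend()
print("Platform: ", platform)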

n_devi…
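Since the rest of the script is cut off, here is a minimal sketch of the approach described above (pmapping the interpolation instead of looping with lax.scan). It assumes the interpolation is 1D linear interpolation with `jnp.interp` over a hypothetical table `xp`, `fp`; the original post's actual interpolation functions are not shown. The helper `interp_all` is also hypothetical: `jax.pmap` maps over the leading axis with one slice per device, so that axis cannot exceed the device count, and the helper pads the query points so they split evenly across devices.

```python
import os
import multiprocessing

# Must be set before jax is imported: one XLA host device per CPU core.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_platforms", "cpu")

n_devices = jax.local_device_count()

# Hypothetical interpolation table (assumption, not from the original post).
xp = np.linspace(0.0, 10.0, 101)
fp = np.sin(xp)


def interpolate_chunk(x_chunk):
    """Linearly interpolate one device's chunk of query points (vectorized)."""
    return jnp.interp(x_chunk, xp, fp)


# pmap runs interpolate_chunk on each device, one leading-axis slice per device.
parallel_interp = jax.pmap(interpolate_chunk)


def interp_all(x):
    """Hypothetical helper: shard x across devices, interpolate, reassemble."""
    n = x.shape[0]
    per_device = -(-n // n_devices)  # ceiling division
    pad = per_device * n_devices - n
    # Pad with the last value so every device gets an equal-sized chunk.
    x_padded = jnp.pad(x, (0, pad), mode="edge")
    shards = x_padded.reshape(n_devices, per_device)
    out = parallel_interp(shards)
    # Flatten back and drop the padded entries.
    return out.reshape(-1)[:n]


x_query = jnp.linspace(0.0, 10.0, 1000)
result = interp_all(x_query)
print(n_devices, result.shape)
```

The design choice is that each device evaluates its whole chunk at once with a vectorized call, so no sequential lax.scan over query points is needed; the only per-device bookkeeping is the padding and reshaping.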
