Skip to content

Pallas/Mosaic support host input. #30102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class MemorySpace(enum.Enum):
ERROR = "error" # Memory space for checkify errors.
INDEX = "index" # Memory space for scalar prefetch arguments.
KEY = "key" # Memory space for PRNG keys.
HOST = "host" # Host memory space.

def __str__(self) -> str:
return self.value
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ def __init__(
# Replace is a method, not a field.
replace = dataclasses.replace


class MemorySpace(enum.Enum):
ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY.
VMEM = "vmem"
SMEM = "smem"
CMEM = "cmem"
SEMAPHORE = "semaphore_mem"
HBM = "hbm"
HOST = "host"

def __str__(self) -> str:
return self.value
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None
case pallas_core.MemorySpace.ANY:
# Map the general ANY memory space to TPU ANY memory space
return TPUMemorySpace.ANY
case pallas_core.MemorySpace.HOST:
return TPUMemorySpace.HOST
case (
pallas_core.MemorySpace.ERROR
| pallas_core.MemorySpace.INDEX
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/pallas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,4 @@


ANY = MemorySpace.ANY
HOST = MemorySpace.HOST
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
I32EnumAttrCase<"kHbm", 2, "hbm">,
I32EnumAttrCase<"kCmem", 3, "cmem">,
I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">,
I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">
I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">,
I32EnumAttrCase<"kHost", 6, "host">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tpu";
Expand Down
64 changes: 64 additions & 0 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,70 @@ def body(sem):
)(x)
np.testing.assert_array_equal(y, x)

def test_host_input_host_to_hbm_dma(self):
if self.INTERPRET:
self.skipTest('Interpret mode does not support host memory.')
if not jtu.if_cloud_tpu_at_least(2025, 7, 12):
self.skipTest("Requires libtpu built after 2025-07-12")
def kernel(x_host_ref, y_hbm_ref):
def body(sem):
pltpu.async_copy(x_host_ref, y_hbm_ref, sem).wait()

pl.run_scoped(body, pltpu.SemaphoreType.DMA)

x = jnp.arange(8 * 128.0).reshape((8, 128))
# Move input to the host.
x = jax.device_put(
x,
jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), 'x'),
jax.sharding.PartitionSpec(),
memory_kind='pinned_host',
),
)
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pl.HOST),
],
out_specs=pl.BlockSpec(memory_space=pl.ANY),
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_array_equal(y, x)

def test_host_input_hbm_to_host_dma(self):
if self.INTERPRET:
self.skipTest('Interpret mode does not support host memory.')
if not jtu.if_cloud_tpu_at_least(2025, 7, 12):
self.skipTest("Requires libtpu built after 2025-07-12")
def kernel(x_host_ref, y_hbm_ref, _):
def body(sem):
pltpu.async_copy(y_hbm_ref, x_host_ref, sem).wait()

pl.run_scoped(body, pltpu.SemaphoreType.DMA)

x = jnp.arange(8 * 128.0).reshape((8, 128))
y = jnp.ones((8, 128))
# Move input to the host.
x = jax.device_put(
x,
jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), 'x'),
jax.sharding.PartitionSpec(),
memory_kind='pinned_host',
),
)
z = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pl.HOST),
pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=pl.BlockSpec(memory_space=pl.ANY),
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x, y)
np.testing.assert_array_equal(x, y)

def test_cannot_dma_with_nonscalar_semaphore_ref(self):
def kernel(x_hbm_ref, y_hbm_ref):
def body(sem):
Expand Down
Loading