Skip to content

Commit e557d4d

Browse files
SandSnip3r authored and Google-ML-Automation committed
Pallas/Mosaic support host input.
PiperOrigin-RevId: 782171959
1 parent f926b3c commit e557d4d

File tree

6 files changed

+71
-2
lines changed

6 files changed

+71
-2
lines changed

jax/_src/pallas/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class MemorySpace(enum.Enum):
228228
ERROR = "error" # Memory space for checkify errors.
229229
INDEX = "index" # Memory space for scalar prefetch arguments.
230230
KEY = "key" # Memory space for PRNG keys.
231+
HOST = "host" # Host memory space.
231232

232233
def __str__(self) -> str:
233234
return self.value

jax/_src/pallas/mosaic/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,14 @@ def __init__(
143143
# Replace is a method, not a field.
144144
replace = dataclasses.replace
145145

146-
147146
class MemorySpace(enum.Enum):
148147
ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY.
149148
VMEM = "vmem"
150149
SMEM = "smem"
151150
CMEM = "cmem"
152151
SEMAPHORE = "semaphore_mem"
153152
HBM = "hbm"
153+
HOST = "host"
154154

155155
def __str__(self) -> str:
156156
return self.value

jax/_src/pallas/mosaic/lowering.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None
241241
case pallas_core.MemorySpace.ANY:
242242
# Map the general ANY memory space to TPU ANY memory space
243243
return TPUMemorySpace.ANY
244+
case pallas_core.MemorySpace.HOST:
245+
return TPUMemorySpace.HOST
244246
case (
245247
pallas_core.MemorySpace.ERROR
246248
| pallas_core.MemorySpace.INDEX

jax/experimental/pallas/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,4 @@
7979

8080

8181
ANY = MemorySpace.ANY
82+
HOST = MemorySpace.HOST

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
163163
I32EnumAttrCase<"kHbm", 2, "hbm">,
164164
I32EnumAttrCase<"kCmem", 3, "cmem">,
165165
I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">,
166-
I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">
166+
I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">,
167+
I32EnumAttrCase<"kHost", 6, "host">
167168
]> {
168169
let genSpecializedAttr = 0;
169170
let cppNamespace = "::mlir::tpu";

tests/pallas/tpu_pallas_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,70 @@ def body(sem):
12381238
)(x)
12391239
np.testing.assert_array_equal(y, x)
12401240

1241+
def test_host_input_host_to_hbm_dma(self):
  """DMAs an input held in pinned host memory into the kernel's output ref.

  The input is passed with `memory_space=pl.HOST`, the output with
  `memory_space=pl.ANY`; the result must equal the original input.
  """
  if self.INTERPRET:
    self.skipTest('Interpret mode does not support host memory.')
  if not jtu.if_cloud_tpu_at_least(2025, 7, 12):
    self.skipTest("Requires libtpu built after 2025-07-12")

  def kernel(src_host_ref, dst_ref):
    def copy_and_wait(dma_sem):
      # Start the host -> output copy and wait for it to complete.
      transfer = pltpu.async_copy(src_host_ref, dst_ref, dma_sem)
      transfer.wait()

    # The DMA semaphore only lives for the duration of `copy_and_wait`.
    pl.run_scoped(copy_and_wait, pltpu.SemaphoreType.DMA)

  # Put the input in pinned host memory (mesh over all devices, empty
  # PartitionSpec).
  mesh = jax.sharding.Mesh(jax.devices(), 'x')
  host_sharding = jax.sharding.NamedSharding(
      mesh,
      jax.sharding.PartitionSpec(),
      memory_kind='pinned_host',
  )
  x = jax.device_put(jnp.arange(8 * 128.0).reshape((8, 128)), host_sharding)

  host_to_out = self.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pl.HOST)],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )
  y = host_to_out(x)
  np.testing.assert_array_equal(y, x)
1271+
1272+
def test_host_input_hbm_to_host_dma(self):
  """DMAs a device input into a pinned-host input buffer and checks the host data.

  The kernel copies its second input (`memory_space=pl.ANY`) into its first
  input (`memory_space=pl.HOST`) rather than into the output ref, so the
  check compares the host-resident `x` against `y` instead of inspecting the
  `pallas_call` result.
  """
  if self.INTERPRET:
    self.skipTest('Interpret mode does not support host memory.')
  if not jtu.if_cloud_tpu_at_least(2025, 7, 12):
    self.skipTest("Requires libtpu built after 2025-07-12")

  def kernel(x_host_ref, y_hbm_ref, _):
    # The trailing ref (presumably the kernel output, given two in_specs and
    # one out_spec) is deliberately left unwritten.
    def body(sem):
      pltpu.async_copy(y_hbm_ref, x_host_ref, sem).wait()

    pl.run_scoped(body, pltpu.SemaphoreType.DMA)

  x = jnp.arange(8 * 128.0).reshape((8, 128))
  y = jnp.ones((8, 128))
  # Move input to the host.
  x = jax.device_put(
      x,
      jax.sharding.NamedSharding(
          jax.sharding.Mesh(jax.devices(), 'x'),
          jax.sharding.PartitionSpec(),
          memory_kind='pinned_host',
      ),
  )
  z = self.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pl.HOST),
          pl.BlockSpec(memory_space=pl.ANY),
      ],
      out_specs=pl.BlockSpec(memory_space=pl.ANY),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x, y)
  # The kernel overwrites `x` as a side effect that is not expressed through
  # the call's outputs, so explicitly wait for the kernel to finish before
  # reading `x` back.
  jax.block_until_ready(z)
  np.testing.assert_array_equal(x, y)
1304+
12411305
def test_cannot_dma_with_nonscalar_semaphore_ref(self):
12421306
def kernel(x_hbm_ref, y_hbm_ref):
12431307
def body(sem):

0 commit comments

Comments
 (0)