Commit 667a0c1

jax authors committed
Add some docstrings for remote DMAs and semaphore barriers.
PiperOrigin-RevId: 627037991
1 parent b79f3b7 commit 667a0c1

1 file changed: +38 -0 lines changed

jax/_src/pallas/mosaic/primitives.py

Lines changed: 38 additions & 0 deletions
@@ -543,6 +543,24 @@ def async_copy(src_ref, dst_ref, sem):
 
 def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
                            device_id_type: DeviceIdType = DeviceIdType.MESH):
+  """Creates a description of a remote copy operation.
+
+  Copies data from src_ref on the current device to dst_ref on the device
+  specified by device_id. Both semaphores should be waited on using the
+  descriptor on both source and target devices.
+
+  Note that device_id can also refer to the current device.
+
+  Args:
+    src_ref: The source Reference.
+    dst_ref: The destination Reference.
+    send_sem: The semaphore on the source device.
+    recv_sem: The semaphore on the destination device.
+    device_id: The device id of the destination device.
+    device_id_type: The type of the device id.
+  Returns:
+    An AsyncCopyDescriptor.
+  """
   src_ref, src_indexers = _get_ref_and_indexers(src_ref)
   send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem)
   dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref)
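
The send/receive semaphore contract in this docstring can be illustrated with a toy model in plain Python. This is only a sketch and not the Pallas API: `ToyDevice`, `ToyRemoteCopy`, and `ring_shift` are hypothetical names, the semaphores are modelled as plain integer counters, and the "devices" run sequentially in one process. It assumes every device sends to its right-hand neighbour, so each device is both a source and a target and waits on both of its semaphores.

```python
from dataclasses import dataclass, field


@dataclass
class ToyDevice:
    """Hypothetical stand-in for one device: buffers plus two counters."""
    src: list
    dst: list = field(default_factory=list)
    send_sem: int = 0  # incremented when this device's outgoing copy is done
    recv_sem: int = 0  # incremented when an incoming copy has landed here


class ToyRemoteCopy:
    """Toy descriptor: start() moves the data, wait() consumes both signals."""

    def __init__(self, devices, my_id, device_id):
        self.devices = devices
        self.my_id = my_id          # the device issuing the copy
        self.device_id = device_id  # the destination (may equal my_id)

    def start(self):
        source = self.devices[self.my_id]
        target = self.devices[self.device_id]
        target.dst = list(source.src)
        source.send_sem += 1
        target.recv_sem += 1

    def wait(self):
        me = self.devices[self.my_id]
        # A real wait would block; the toy model just checks the counters.
        assert me.send_sem > 0 and me.recv_sem > 0, "wait would block"
        me.send_sem -= 1
        me.recv_sem -= 1


def ring_shift(values):
    """Each toy device copies its value to its right-hand neighbour."""
    n = len(values)
    devices = [ToyDevice(src=[v]) for v in values]
    copies = [ToyRemoteCopy(devices, i, (i + 1) % n) for i in range(n)]
    for copy in copies:
        copy.start()
    for copy in copies:
        copy.wait()
    # Every semaphore has been waited back down to zero.
    assert all(d.send_sem == 0 and d.recv_sem == 0 for d in devices)
    return [d.dst[0] for d in devices]
```

With a single device, `ring_shift` degenerates to a copy whose destination is the current device, which is the case the docstring's note about device_id covers.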
@@ -576,4 +594,24 @@ def _get_barrier_semaphore_abstract_eval():
   )
 
 def get_barrier_semaphore():
+  """Returns a barrier semaphore.
+
+  This function returns a barrier semaphore based on the collective_id of the
+  current pallas kernel.
+
+  It's very important that the semaphore is wait-ed back down to 0, or else the
+  semaphores will become corrupted.
+
+  It's also very important that the collective_id is different for each pallas
+  kernel with communication. E.g. if you have two pallas kernels, one that syncs
+  across the X axis of the device mesh and the second that syncs across the Y
+  axis, they must have different collective_ids.
+  However it is legal for two kernels that perform the same synchronization
+  pattern (e.g. only communicating with neighbours on the same mesh axis)
+  to share a collective_id. However, if in doubt, prefer not sharing
+  collective_ids, as doing so incorrectly can lead to silent data corruption or
+  crashes.
+  Note that re-using the same collective_id doesn't guarantee that the same
+  semaphore is provided by XLA.
+  """
   return get_barrier_semaphore_p.bind()
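
The requirement that the barrier semaphore be waited back down to 0 can be seen in a small counting model. Again this is a plain-Python sketch under stated assumptions, not Pallas code: `neighbor_barrier` and its `drain` flag are hypothetical, each barrier semaphore is an integer in a list, and the pattern assumed is that every device signals its left and right neighbours on a ring and then waits for two signals.

```python
def neighbor_barrier(counters, drain=True):
    """Toy neighbour barrier over a ring of integer 'semaphores'.

    counters[i] models the barrier semaphore value on device i.
    """
    n = len(counters)
    # Signal phase: every device increments both neighbours' semaphores.
    for i in range(n):
        counters[(i - 1) % n] += 1
        counters[(i + 1) % n] += 1
    # Wait phase: every device consumes the two signals it received.
    if drain:
        for i in range(n):
            assert counters[i] >= 2, "wait would block"
            counters[i] -= 2
    return counters


clean = neighbor_barrier([0, 0, 0, 0])               # all counters back at 0
leaky = neighbor_barrier([0, 0, 0, 0], drain=False)  # residue of 2 per device
```

In the `leaky` case each counter is left at 2. In this model, a later kernel that happened to be handed the same semaphore would see its wait for two signals satisfied before any neighbour had actually signalled, which is one way the corruption described in the docstring could show up.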
