@@ -543,6 +543,24 @@ def async_copy(src_ref, dst_ref, sem):
543
543
544
544
def make_async_remote_copy (src_ref , dst_ref , send_sem , recv_sem , device_id ,
545
545
device_id_type : DeviceIdType = DeviceIdType .MESH ):
546
+ """Creates a description of a remote copy operation.
547
+
548
+ Copies data from src_ref on the current device to dst_ref on the device
549
+ specified by device_id. Both semaphores should be waited on using the
550
+ descriptor on both source and target devices.
551
+
552
+ Note that device_id can also refer to the current device.
553
+
554
+ Args:
555
+ src_ref: The source Reference.
556
+ dst_ref: The destination Reference.
557
+ send_sem: The semaphore on the source device.
558
+ recv_sem: The semaphore on the destination device.
559
+ device_id: The device id of the destination device.
560
+ device_id_type: The type of the device id. Defaults to DeviceIdType.MESH.
561
+ Returns:
562
+ An AsyncCopyDescriptor.
563
+ """
546
564
src_ref , src_indexers = _get_ref_and_indexers (src_ref )
547
565
send_sem , send_sem_indexers = _get_ref_and_indexers (send_sem )
548
566
dst_ref , dst_indexers = _get_ref_and_indexers (dst_ref )
@@ -576,4 +594,24 @@ def _get_barrier_semaphore_abstract_eval():
576
594
)
577
595
578
596
def get_barrier_semaphore():
  """Fetches the barrier semaphore for the current pallas kernel.

  The semaphore is chosen based on the collective_id of the current pallas
  kernel.

  Usage rules:

  * Always wait the semaphore back down to 0; otherwise the semaphores will
    become corrupted.
  * Every pallas kernel with communication must use a different
    collective_id. For example, given two pallas kernels where one syncs
    across the X axis of the device mesh and the other syncs across the Y
    axis, the two must have different collective_ids.
  * Two kernels may legally share a collective_id when they perform the same
    synchronization pattern (e.g. only communicating with neighbours on the
    same mesh axis). If in doubt, prefer not sharing collective_ids: sharing
    them incorrectly can lead to silent data corruption or crashes.

  Note that re-using the same collective_id doesn't guarantee that the same
  semaphore is provided by XLA.

  Returns:
    The barrier semaphore, i.e. the result of binding the
    get_barrier_semaphore_p primitive.
  """
  barrier_sem = get_barrier_semaphore_p.bind()
  return barrier_sem
0 commit comments