Skip to content

Commit 0e797a2

Browse files
epiquerasjax authors
authored andcommitted
[Pallas TPU Pipeline] Add support for arbitrary manual prefetches in pipeline callbacks + compute execution.
PiperOrigin-RevId: 610592093
1 parent 596756f commit 0e797a2

File tree

3 files changed

+144
-42
lines changed

3 files changed

+144
-42
lines changed

jax/_src/pallas/mosaic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from jax._src.pallas.mosaic.pipeline import emit_pipeline
2929
from jax._src.pallas.mosaic.pipeline import PipelineCallbackArgs
3030
from jax._src.pallas.mosaic.pipeline import PipelinePrefetchArgs
31+
from jax._src.pallas.mosaic.pipeline import ManualPrefetchArgs
3132
from jax._src.pallas.mosaic.primitives import DeviceIdType
3233
from jax._src.pallas.mosaic.primitives import async_copy
3334
from jax._src.pallas.mosaic.primitives import async_remote_copy

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 142 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,39 @@ def __call__(
408408
force_copy: Union[
409409
bool, tuple[Union[CondVal, Any], Union[CondVal, Any]]
410410
] = False,
411-
) -> PipelineArg[PipelineBuffers]:
411+
force_skip: Union[
412+
bool, tuple[Union[CondVal, Any], Union[CondVal, Any]]
413+
] = False,
414+
) -> tuple[PipelineBuffers, PipelineBuffers]:
415+
...
416+
417+
418+
@dataclasses.dataclass(frozen=True)
419+
class ManualPrefetchArgs:
420+
"""Args for pipeline prefetch."""
421+
422+
pipeline_specs: PipelineBlockSpecs
423+
pipeline_refs: PipelineRefs
424+
pipeline_allocations: PipelineAllocations
425+
pipeline_buffers: PipelineBuffers
426+
427+
428+
class StartManualPrefetch(Protocol):
429+
"""Starts manual prefetch.
430+
431+
Use force_copy if a spec's indices don't change from last to first grid
432+
indices and you still want to force a copy. This must be used in conjunction
433+
with the prologue's return value to force a wait.
434+
"""
435+
436+
def __call__(
437+
self,
438+
prefetch_args: ManualPrefetchArgs,
439+
*,
440+
indices: GridIndices,
441+
force_copy: Union[bool, Union[CondVal, Any]] = False,
442+
force_skip: Union[bool, Union[CondVal, Any]] = False,
443+
) -> PipelineBuffers:
412444
...
413445

414446

@@ -422,6 +454,8 @@ class PipelineCallbackArgs:
422454
pipeline_buffers: PipelineArg[PipelineBuffers]
423455
make_pipeline_refs: MakePipelineRefs
424456
start_pipeline_prefetch: StartPipelinePrefetch
457+
start_manual_prefetch: StartManualPrefetch
458+
run_manual_compute: Callable[[Callable[[], None]], None]
425459

426460

427461
PipelinePrologue = Callable[
@@ -437,6 +471,8 @@ class PipelineCallbackArgs:
437471
PipelineEpilogue = Callable[
438472
[PipelineCallbackArgs], tuple[PipelineBuffers, PipelineBuffers]
439473
]
474+
PipelineOutPrologue = Callable[[PipelineCallbackArgs], Union[CondVal, Any]]
475+
PipelineOutEpilogue = Callable[[PipelineCallbackArgs], Union[CondVal, Any]]
440476

441477

442478
class Pipeline(Protocol):
@@ -449,8 +485,8 @@ def __call__(
449485
init_allocations: CondVal = False,
450486
prologue: Union[PipelinePrologue, None] = None,
451487
epilogue: Union[PipelineEpilogue, None] = None,
452-
out_prologue: Union[PipelinePrologue, None] = None,
453-
out_epilogue: Union[PipelineEpilogue, None] = None,
488+
out_prologue: Union[PipelineOutPrologue, None] = None,
489+
out_epilogue: Union[PipelineOutEpilogue, None] = None,
454490
) -> None:
455491
...
456492

@@ -523,40 +559,81 @@ def start_pipeline_prefetch(
523559
force_copy: Union[
524560
bool, tuple[Union[CondVal, Any], Union[CondVal, Any]]
525561
] = False,
562+
force_skip: Union[
563+
bool, tuple[Union[CondVal, Any], Union[CondVal, Any]]
564+
] = False,
526565
) -> tuple[PipelineBuffers, PipelineBuffers]:
527566
if isinstance(force_copy, bool):
528-
next_in_and_in_out_buffers = tree_util.tree_map(
529-
partial(_start_block_copy_in, indices=indices, force_copy=force_copy),
530-
pipeline_specs.input_and_in_out,
531-
prefetch_args.pipeline_refs.input_and_in_out,
532-
prefetch_args.pipeline_allocations.input_and_in_out,
533-
prefetch_args.pipeline_buffers.input_and_in_out,
534-
)
535-
else:
536-
force_input_copy, force_in_out_copy = force_copy
537-
force_input_copy = _broadcast_pytree_to(
538-
"force_input_copy",
539-
force_input_copy,
540-
pipeline_specs.input,
541-
)
542-
force_in_out_copy = _broadcast_pytree_to(
543-
"force_in_out_copy",
544-
force_in_out_copy,
545-
pipeline_specs.out,
546-
)
547-
next_in_and_in_out_buffers = _tree_map_with_kwargs(
548-
partial(_start_block_copy_in, indices=indices),
549-
pipeline_specs.input_and_in_out,
550-
prefetch_args.pipeline_refs.input_and_in_out,
551-
prefetch_args.pipeline_allocations.input_and_in_out,
552-
prefetch_args.pipeline_buffers.input_and_in_out,
553-
force_copy=force_input_copy + force_in_out_copy,
554-
)
567+
force_copy = (force_copy, force_copy)
568+
if isinstance(force_skip, bool):
569+
force_skip = (force_skip, force_skip)
570+
force_input_copy, force_in_out_copy = force_copy
571+
force_input_copy = _broadcast_pytree_to(
572+
"force_input_copy",
573+
force_input_copy,
574+
pipeline_specs.input,
575+
)
576+
force_in_out_copy = _broadcast_pytree_to(
577+
"force_in_out_copy",
578+
force_in_out_copy,
579+
pipeline_specs.in_out,
580+
)
581+
force_input_skip, force_in_out_skip = force_skip
582+
force_input_skip = _broadcast_pytree_to(
583+
"force_input_skip",
584+
force_input_skip,
585+
pipeline_specs.input,
586+
)
587+
force_in_out_skip = _broadcast_pytree_to(
588+
"force_in_out_skip",
589+
force_in_out_skip,
590+
pipeline_specs.in_out,
591+
)
592+
next_in_and_in_out_buffers = _tree_map_with_kwargs(
593+
partial(_start_block_copy_in, indices=indices),
594+
pipeline_specs.input_and_in_out,
595+
prefetch_args.pipeline_refs.input_and_in_out,
596+
prefetch_args.pipeline_allocations.input_and_in_out,
597+
prefetch_args.pipeline_buffers.input_and_in_out,
598+
force_copy=force_input_copy + force_in_out_copy,
599+
force_skip=force_input_skip + force_in_out_skip,
600+
)
555601
next_in_buffers, next_in_out_buffers = split_list(
556602
next_in_and_in_out_buffers, [len(pipeline_specs.input)]
557603
)
558604
return next_in_buffers, next_in_out_buffers
559605

606+
def start_manual_prefetch(
607+
prefetch_args: ManualPrefetchArgs,
608+
*,
609+
indices: GridIndices,
610+
force_copy: Union[bool, Union[CondVal, Any]] = False,
611+
force_skip: Union[bool, Union[CondVal, Any]] = False,
612+
) -> PipelineBuffers:
613+
force_copy = _broadcast_pytree_to(
614+
"force_input_copy",
615+
force_copy,
616+
prefetch_args.pipeline_specs,
617+
)
618+
force_skip = _broadcast_pytree_to(
619+
"force_skip",
620+
force_skip,
621+
prefetch_args.pipeline_specs,
622+
)
623+
next_buffers = _tree_map_with_kwargs(
624+
partial(_start_block_copy_in, indices=indices),
625+
prefetch_args.pipeline_specs,
626+
prefetch_args.pipeline_refs,
627+
prefetch_args.pipeline_allocations,
628+
prefetch_args.pipeline_buffers,
629+
force_copy=force_copy,
630+
force_skip=force_skip,
631+
)
632+
return next_buffers
633+
634+
def run_manual_compute(fn: Callable[[], None]) -> None:
635+
fn()
636+
560637
def make_pipeline_allocations(
561638
*ref_args: PipelineRefs,
562639
return_treedef: bool = False,
@@ -623,8 +700,8 @@ def pipeline(
623700
init_allocations: CondVal = False,
624701
prologue: Union[PipelinePrologue, None] = None,
625702
epilogue: Union[PipelineEpilogue, None] = None,
626-
out_prologue: Union[PipelinePrologue, None] = None,
627-
out_epilogue: Union[PipelineEpilogue, None] = None,
703+
out_prologue: Union[PipelineOutPrologue, None] = None,
704+
out_epilogue: Union[PipelineOutEpilogue, None] = None,
628705
) -> None:
629706
use_in_out = jnp.logical_not(init_allocations)
630707
if scratchs is None:
@@ -683,16 +760,20 @@ def init_buffer_ref(_, buffer_ref):
683760
make_pipeline_refs=make_pipeline_refs,
684761
start_pipeline_prefetch=partial(
685762
cast(Any, start_pipeline_prefetch),
686-
indices=(zero_indices, last_indices, indices),
687-
force_copy=True,
763+
indices=(last_indices, zero_indices, indices),
764+
),
765+
start_manual_prefetch=partial(
766+
cast(Any, start_manual_prefetch),
767+
indices=(last_indices, zero_indices, indices),
688768
),
769+
run_manual_compute=run_manual_compute,
689770
)
690771
)
691772
else:
692-
skip_input_prologue = None
693-
skip_in_out_prologue = None
694-
force_input_prologue_wait = None
695-
force_in_out_prologue_wait = None
773+
skip_input_prologue = False
774+
skip_in_out_prologue = False
775+
force_input_prologue_wait = False
776+
force_in_out_prologue_wait = False
696777
skip_input_prologue = _broadcast_pytree_to(
697778
"skip_input_prologue",
698779
skip_input_prologue,
@@ -851,6 +932,11 @@ def run_epilogue():
851932
cast(Any, start_pipeline_prefetch),
852933
indices=(prev_indices, indices, zero_indices),
853934
),
935+
start_manual_prefetch=partial(
936+
cast(Any, start_manual_prefetch),
937+
indices=(prev_indices, indices, zero_indices),
938+
),
939+
run_manual_compute=run_manual_compute,
854940
)
855941
)
856942

@@ -961,8 +1047,12 @@ def run_out_prologue():
9611047
start_pipeline_prefetch=partial(
9621048
cast(Any, start_pipeline_prefetch),
9631049
indices=copy_indices,
964-
force_copy=True,
9651050
),
1051+
start_manual_prefetch=partial(
1052+
cast(Any, start_manual_prefetch),
1053+
indices=copy_indices,
1054+
),
1055+
run_manual_compute=run_manual_compute,
9661056
)
9671057
)
9681058
skip_out_prologue_wait = _broadcast_pytree_to(
@@ -1009,7 +1099,13 @@ def wait_prev_iteration_out_block_copies():
10091099
pipeline_allocations=pipeline_allocations,
10101100
pipeline_buffers=pipeline_buffers,
10111101
make_pipeline_refs=make_pipeline_refs,
1012-
start_pipeline_prefetch=cast(Any, lambda *args, **kwargs: None),
1102+
start_pipeline_prefetch=cast(
1103+
Any, lambda *args, **kwargs: None
1104+
),
1105+
start_manual_prefetch=cast(
1106+
Any, lambda *args, **kwargs: None
1107+
),
1108+
run_manual_compute=cast(Any, lambda *args, **kwargs: None),
10131109
)
10141110
)
10151111
else:
@@ -1084,9 +1180,13 @@ def set_buffer_ref(buffer_ref, buffer):
10841180
make_pipeline_refs=make_pipeline_refs,
10851181
start_pipeline_prefetch=partial(
10861182
cast(Any, start_pipeline_prefetch),
1087-
indices=(zero_indices, zero_indices, zero_indices),
1088-
force_copy=True,
1183+
indices=(prev_indices, indices, zero_indices),
1184+
),
1185+
start_manual_prefetch=partial(
1186+
cast(Any, start_manual_prefetch),
1187+
indices=(prev_indices, indices, zero_indices),
10891188
),
1189+
run_manual_compute=run_manual_compute,
10901190
)
10911191
)
10921192
else:

jax/experimental/pallas/tpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from jax._src.pallas.mosaic import emit_pipeline
3131
from jax._src.pallas.mosaic import PipelineCallbackArgs
3232
from jax._src.pallas.mosaic import PipelinePrefetchArgs
33+
from jax._src.pallas.mosaic import ManualPrefetchArgs
3334
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
3435
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
3536
from jax._src.pallas.mosaic import get_barrier_semaphore

0 commit comments

Comments
 (0)