@@ -408,7 +408,39 @@ def __call__(
408
408
force_copy : Union [
409
409
bool , tuple [Union [CondVal , Any ], Union [CondVal , Any ]]
410
410
] = False ,
411
- ) -> PipelineArg [PipelineBuffers ]:
411
+ force_skip : Union [
412
+ bool , tuple [Union [CondVal , Any ], Union [CondVal , Any ]]
413
+ ] = False ,
414
+ ) -> tuple [PipelineBuffers , PipelineBuffers ]:
415
+ ...
416
+
417
+
418
@dataclasses.dataclass(frozen=True)
class ManualPrefetchArgs:
  """Immutable bundle of the state a manual prefetch operates on.

  `start_manual_prefetch` maps over all four fields together, so they are
  presumably pytrees with matching structure -- TODO confirm against callers.
  """

  # Block specs of the operands to prefetch.
  pipeline_specs: PipelineBlockSpecs
  # Refs paired leaf-for-leaf with `pipeline_specs` when the copy is started.
  pipeline_refs: PipelineRefs
  # Allocations paired leaf-for-leaf with `pipeline_specs`.
  pipeline_allocations: PipelineAllocations
  # Buffers passed in to the copy-start step; the prefetch returns a new
  # `PipelineBuffers` value rather than mutating this one (class is frozen).
  pipeline_buffers: PipelineBuffers
426
+
427
+
428
class StartManualPrefetch(Protocol):
  """Callable protocol for starting a manual prefetch.

  Pass `force_copy` when a spec's indices are unchanged between the last and
  the first grid indices but a copy should be issued anyway. A forced copy
  must be paired with the prologue's return value so that the matching wait
  is forced as well.

  NOTE(review): `force_skip` is forwarded next to `force_copy` by the
  implementation in this file; presumably it suppresses the copy -- confirm
  against `_start_block_copy_in`.
  """

  def __call__(
      self,
      prefetch_args: ManualPrefetchArgs,
      *,
      indices: GridIndices,
      force_copy: Union[bool, Union[CondVal, Any]] = False,
      force_skip: Union[bool, Union[CondVal, Any]] = False,
  ) -> PipelineBuffers:
    ...
413
445
414
446
@@ -422,6 +454,8 @@ class PipelineCallbackArgs:
422
454
pipeline_buffers : PipelineArg [PipelineBuffers ]
423
455
make_pipeline_refs : MakePipelineRefs
424
456
start_pipeline_prefetch : StartPipelinePrefetch
457
+ start_manual_prefetch : StartManualPrefetch
458
+ run_manual_compute : Callable [[Callable [[], None ]], None ]
425
459
426
460
427
461
PipelinePrologue = Callable [
@@ -437,6 +471,8 @@ class PipelineCallbackArgs:
437
471
PipelineEpilogue = Callable [
438
472
[PipelineCallbackArgs ], tuple [PipelineBuffers , PipelineBuffers ]
439
473
]
474
+ PipelineOutPrologue = Callable [[PipelineCallbackArgs ], Union [CondVal , Any ]]
475
+ PipelineOutEpilogue = Callable [[PipelineCallbackArgs ], Union [CondVal , Any ]]
440
476
441
477
442
478
class Pipeline (Protocol ):
@@ -449,8 +485,8 @@ def __call__(
449
485
init_allocations : CondVal = False ,
450
486
prologue : Union [PipelinePrologue , None ] = None ,
451
487
epilogue : Union [PipelineEpilogue , None ] = None ,
452
- out_prologue : Union [PipelinePrologue , None ] = None ,
453
- out_epilogue : Union [PipelineEpilogue , None ] = None ,
488
+ out_prologue : Union [PipelineOutPrologue , None ] = None ,
489
+ out_epilogue : Union [PipelineOutEpilogue , None ] = None ,
454
490
) -> None :
455
491
...
456
492
@@ -523,40 +559,81 @@ def start_pipeline_prefetch(
523
559
force_copy : Union [
524
560
bool , tuple [Union [CondVal , Any ], Union [CondVal , Any ]]
525
561
] = False ,
562
+ force_skip : Union [
563
+ bool , tuple [Union [CondVal , Any ], Union [CondVal , Any ]]
564
+ ] = False ,
526
565
) -> tuple [PipelineBuffers , PipelineBuffers ]:
527
566
if isinstance (force_copy , bool ):
528
- next_in_and_in_out_buffers = tree_util .tree_map (
529
- partial (_start_block_copy_in , indices = indices , force_copy = force_copy ),
530
- pipeline_specs .input_and_in_out ,
531
- prefetch_args .pipeline_refs .input_and_in_out ,
532
- prefetch_args .pipeline_allocations .input_and_in_out ,
533
- prefetch_args .pipeline_buffers .input_and_in_out ,
534
- )
535
- else :
536
- force_input_copy , force_in_out_copy = force_copy
537
- force_input_copy = _broadcast_pytree_to (
538
- "force_input_copy" ,
539
- force_input_copy ,
540
- pipeline_specs .input ,
541
- )
542
- force_in_out_copy = _broadcast_pytree_to (
543
- "force_in_out_copy" ,
544
- force_in_out_copy ,
545
- pipeline_specs .out ,
546
- )
547
- next_in_and_in_out_buffers = _tree_map_with_kwargs (
548
- partial (_start_block_copy_in , indices = indices ),
549
- pipeline_specs .input_and_in_out ,
550
- prefetch_args .pipeline_refs .input_and_in_out ,
551
- prefetch_args .pipeline_allocations .input_and_in_out ,
552
- prefetch_args .pipeline_buffers .input_and_in_out ,
553
- force_copy = force_input_copy + force_in_out_copy ,
554
- )
567
+ force_copy = (force_copy , force_copy )
568
+ if isinstance (force_skip , bool ):
569
+ force_skip = (force_skip , force_skip )
570
+ force_input_copy , force_in_out_copy = force_copy
571
+ force_input_copy = _broadcast_pytree_to (
572
+ "force_input_copy" ,
573
+ force_input_copy ,
574
+ pipeline_specs .input ,
575
+ )
576
+ force_in_out_copy = _broadcast_pytree_to (
577
+ "force_in_out_copy" ,
578
+ force_in_out_copy ,
579
+ pipeline_specs .in_out ,
580
+ )
581
+ force_input_skip , force_in_out_skip = force_skip
582
+ force_input_skip = _broadcast_pytree_to (
583
+ "force_input_skip" ,
584
+ force_input_skip ,
585
+ pipeline_specs .input ,
586
+ )
587
+ force_in_out_skip = _broadcast_pytree_to (
588
+ "force_in_out_skip" ,
589
+ force_in_out_skip ,
590
+ pipeline_specs .in_out ,
591
+ )
592
+ next_in_and_in_out_buffers = _tree_map_with_kwargs (
593
+ partial (_start_block_copy_in , indices = indices ),
594
+ pipeline_specs .input_and_in_out ,
595
+ prefetch_args .pipeline_refs .input_and_in_out ,
596
+ prefetch_args .pipeline_allocations .input_and_in_out ,
597
+ prefetch_args .pipeline_buffers .input_and_in_out ,
598
+ force_copy = force_input_copy + force_in_out_copy ,
599
+ force_skip = force_input_skip + force_in_out_skip ,
600
+ )
555
601
next_in_buffers , next_in_out_buffers = split_list (
556
602
next_in_and_in_out_buffers , [len (pipeline_specs .input )]
557
603
)
558
604
return next_in_buffers , next_in_out_buffers
559
605
606
def start_manual_prefetch(
    prefetch_args: ManualPrefetchArgs,
    *,
    indices: GridIndices,
    force_copy: Union[bool, Union[CondVal, Any]] = False,
    force_skip: Union[bool, Union[CondVal, Any]] = False,
) -> PipelineBuffers:
  """Starts block copies for the operands bundled in `prefetch_args`.

  Args:
    prefetch_args: Specs, refs, allocations and buffers that are mapped over
      together, leaf by leaf.
    indices: Grid indices forwarded unchanged to `_start_block_copy_in`.
    force_copy: A single value or a pytree; broadcast against
      `prefetch_args.pipeline_specs` and handed to `_start_block_copy_in` as
      its per-leaf `force_copy` keyword.
    force_skip: Broadcast and forwarded exactly like `force_copy`, as the
      per-leaf `force_skip` keyword.

  Returns:
    The result of mapping `_start_block_copy_in` over the bundled pytrees.
  """
  specs = prefetch_args.pipeline_specs
  # Label each broadcast with this function's own parameter name. The label
  # is presumably what `_broadcast_pytree_to` reports on a structure mismatch
  # (TODO confirm), and there is no input/in_out split here, so the former
  # "force_input_copy" label pointed at a parameter that does not exist.
  force_copy = _broadcast_pytree_to("force_copy", force_copy, specs)
  force_skip = _broadcast_pytree_to("force_skip", force_skip, specs)
  return _tree_map_with_kwargs(
      partial(_start_block_copy_in, indices=indices),
      specs,
      prefetch_args.pipeline_refs,
      prefetch_args.pipeline_allocations,
      prefetch_args.pipeline_buffers,
      force_copy=force_copy,
      force_skip=force_skip,
  )
633
+
634
def run_manual_compute(fn: Callable[[], None]) -> None:
  """Invokes `fn` once, immediately, and discards whatever it returns."""
  fn()
  return None
636
+
560
637
def make_pipeline_allocations (
561
638
* ref_args : PipelineRefs ,
562
639
return_treedef : bool = False ,
@@ -623,8 +700,8 @@ def pipeline(
623
700
init_allocations : CondVal = False ,
624
701
prologue : Union [PipelinePrologue , None ] = None ,
625
702
epilogue : Union [PipelineEpilogue , None ] = None ,
626
- out_prologue : Union [PipelinePrologue , None ] = None ,
627
- out_epilogue : Union [PipelineEpilogue , None ] = None ,
703
+ out_prologue : Union [PipelineOutPrologue , None ] = None ,
704
+ out_epilogue : Union [PipelineOutEpilogue , None ] = None ,
628
705
) -> None :
629
706
use_in_out = jnp .logical_not (init_allocations )
630
707
if scratchs is None :
@@ -683,16 +760,20 @@ def init_buffer_ref(_, buffer_ref):
683
760
make_pipeline_refs = make_pipeline_refs ,
684
761
start_pipeline_prefetch = partial (
685
762
cast (Any , start_pipeline_prefetch ),
686
- indices = (zero_indices , last_indices , indices ),
687
- force_copy = True ,
763
+ indices = (last_indices , zero_indices , indices ),
764
+ ),
765
+ start_manual_prefetch = partial (
766
+ cast (Any , start_manual_prefetch ),
767
+ indices = (last_indices , zero_indices , indices ),
688
768
),
769
+ run_manual_compute = run_manual_compute ,
689
770
)
690
771
)
691
772
else :
692
- skip_input_prologue = None
693
- skip_in_out_prologue = None
694
- force_input_prologue_wait = None
695
- force_in_out_prologue_wait = None
773
+ skip_input_prologue = False
774
+ skip_in_out_prologue = False
775
+ force_input_prologue_wait = False
776
+ force_in_out_prologue_wait = False
696
777
skip_input_prologue = _broadcast_pytree_to (
697
778
"skip_input_prologue" ,
698
779
skip_input_prologue ,
@@ -851,6 +932,11 @@ def run_epilogue():
851
932
cast (Any , start_pipeline_prefetch ),
852
933
indices = (prev_indices , indices , zero_indices ),
853
934
),
935
+ start_manual_prefetch = partial (
936
+ cast (Any , start_manual_prefetch ),
937
+ indices = (prev_indices , indices , zero_indices ),
938
+ ),
939
+ run_manual_compute = run_manual_compute ,
854
940
)
855
941
)
856
942
@@ -961,8 +1047,12 @@ def run_out_prologue():
961
1047
start_pipeline_prefetch = partial (
962
1048
cast (Any , start_pipeline_prefetch ),
963
1049
indices = copy_indices ,
964
- force_copy = True ,
965
1050
),
1051
+ start_manual_prefetch = partial (
1052
+ cast (Any , start_manual_prefetch ),
1053
+ indices = copy_indices ,
1054
+ ),
1055
+ run_manual_compute = run_manual_compute ,
966
1056
)
967
1057
)
968
1058
skip_out_prologue_wait = _broadcast_pytree_to (
@@ -1009,7 +1099,13 @@ def wait_prev_iteration_out_block_copies():
1009
1099
pipeline_allocations = pipeline_allocations ,
1010
1100
pipeline_buffers = pipeline_buffers ,
1011
1101
make_pipeline_refs = make_pipeline_refs ,
1012
- start_pipeline_prefetch = cast (Any , lambda * args , ** kwargs : None ),
1102
+ start_pipeline_prefetch = cast (
1103
+ Any , lambda * args , ** kwargs : None
1104
+ ),
1105
+ start_manual_prefetch = cast (
1106
+ Any , lambda * args , ** kwargs : None
1107
+ ),
1108
+ run_manual_compute = cast (Any , lambda * args , ** kwargs : None ),
1013
1109
)
1014
1110
)
1015
1111
else :
@@ -1084,9 +1180,13 @@ def set_buffer_ref(buffer_ref, buffer):
1084
1180
make_pipeline_refs = make_pipeline_refs ,
1085
1181
start_pipeline_prefetch = partial (
1086
1182
cast (Any , start_pipeline_prefetch ),
1087
- indices = (zero_indices , zero_indices , zero_indices ),
1088
- force_copy = True ,
1183
+ indices = (prev_indices , indices , zero_indices ),
1184
+ ),
1185
+ start_manual_prefetch = partial (
1186
+ cast (Any , start_manual_prefetch ),
1187
+ indices = (prev_indices , indices , zero_indices ),
1089
1188
),
1189
+ run_manual_compute = run_manual_compute ,
1090
1190
)
1091
1191
)
1092
1192
else :
0 commit comments