@@ -545,30 +545,15 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
   // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
 
-  // If `stream` is in capture mode we can't run autotuning on it as we don't
-  // want to capture it into a graph. We create a new stream to do autotuning
-  // and destroy it when we are done.
+  // Autotuning is not supported if the stream is in graph capture mode.
   gpustreamCaptureStatus_t capture_status;
   GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
-  bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;
-
-  gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
-  gpuStream_t autotune_stream = stream;
-
-  // An event that synchronizes autotuning stream with a main one.
-  gpuEvent_t autotune_event = nullptr;
-
-  if (is_capturing) {
-    GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
-
-    // Record event after completion of launched kernels on the main stream.
-    GPU_RETURN_IF_ERROR(gpuEventCreate(&autotune_event, 0));
-    GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, stream));
-
-    // Create a new stream to run autotuning and synchronize it with main sream.
-    GPU_RETURN_IF_ERROR(
-        gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
-    GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(autotune_stream, autotune_event));
+  if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
+    return absl::FailedPreconditionError(
+        "Can't autotune Triton kernel when the stream is in graph capture "
+        "mode. Autotuning can rely on real data present in input buffers to "
+        "use them in address computation, but in graph capture mode buffers "
+        "can have arbitrary data");
   }
 
   // If an input aliases with an output, it will get overwritten during the
@@ -581,8 +566,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
       std::vector<uint8_t> input_copy(size);
       GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
           input_copy.data(),
-          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size,
-          autotune_stream));
+          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size, stream));
       input_copies[input_idx] = std::move(input_copy);
     }
   }
@@ -592,8 +576,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   // iterations to run for benchmarking.
   float best = std::numeric_limits<float>::infinity();
   for (Config& config : kernel_call.configs_) {
-    JAX_ASSIGN_OR_RETURN(
-        float t, Benchmark(autotune_stream, config.kernel_call, buffers, 1));
+    JAX_ASSIGN_OR_RETURN(float t,
+                         Benchmark(stream, config.kernel_call, buffers, 1));
     LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
     best = std::min(best, t);
   }
@@ -609,16 +593,16 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   }
 
   best = std::numeric_limits<float>::infinity();
-  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(autotune_stream));
+  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
   for (Config& config : kernel_call.configs_) {
     if (!config.kernel_call.CanLaunchOnDevice(device)) {
       LOG(WARNING) << "Unable to launch autotune config on device: "
                    << config.description;
       continue;
     }
 
-    JAX_ASSIGN_OR_RETURN(float t, Benchmark(autotune_stream, config.kernel_call,
-                                            buffers, timed_iters));
+    JAX_ASSIGN_OR_RETURN(
+        float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
     LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
               << t << " ms";
 
@@ -639,25 +623,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
 
   // Restore aliased inputs to their original values.
   for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
-    GPU_RETURN_IF_ERROR(gpuMemcpyHtoDAsync(
-        reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
-        input_copies[input_idx].data(), size, autotune_stream));
+    GPU_RETURN_IF_ERROR(
+        gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
+                           input_copies[input_idx].data(), size, stream));
   }
 
   // Synchronize stream to ensure copies are complete before the host copy
   // is deleted.
-  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream));
-
-  if (is_capturing) {
-    // Wait on a main stream for completion of autotuning.
-    GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, autotune_stream));
-    GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(stream, autotune_event));
-    GPU_RETURN_IF_ERROR(gpuEventDestroy(autotune_event));
-
-    // Destroy autotuning stream and recover stream capturing mode.
-    GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream));
-    GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
-  }
+  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));
 
   return std::move(kernel_call.configs_[0].kernel_call);
 }
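
The change above replaces the side-stream autotuning path with an early rejection when the stream is being captured into a graph. For illustration only, here is a minimal, hypothetical sketch of that guard in isolation, written against the CUDA runtime API (cudaStreamIsCapturing) rather than the jaxlib gpu* portability aliases used in the patch; the function name CheckStreamNotCapturing and the exact error wording are not part of the diff.

```cpp
#include <cuda_runtime.h>

#include "absl/status/status.h"

// Sketch: refuse to proceed if `stream` is currently in graph capture mode,
// mirroring the precondition check added in the patch.
absl::Status CheckStreamNotCapturing(cudaStream_t stream) {
  cudaStreamCaptureStatus capture_status;
  if (cudaError_t err = cudaStreamIsCapturing(stream, &capture_status);
      err != cudaSuccess) {
    return absl::InternalError(cudaGetErrorString(err));
  }
  if (capture_status == cudaStreamCaptureStatusActive) {
    // Autotuning benchmarks kernels on the real contents of the input
    // buffers; during graph capture those buffers may hold arbitrary data,
    // so bail out instead of capturing the benchmark launches into the graph.
    return absl::FailedPreconditionError(
        "Can't autotune Triton kernel while the stream is in graph capture "
        "mode.");
  }
  return absl::OkStatus();
}
```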