
Commit 1ae2022

ezhulenev authored and jax authors committed
[jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
1 parent 8e2a8b7 commit 1ae2022

File tree

1 file changed: +17 -44 lines changed


jaxlib/gpu/triton_kernels.cc

Lines changed: 17 additions & 44 deletions
@@ -545,30 +545,15 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
   // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };
 
-  // If `stream` is in capture mode we can't run autotuning on it as we don't
-  // want to capture it into a graph. We create a new stream to do autotuning
-  // and destroy it when we are done.
+  // Autotuning is not supported if the the stream is in graph capture mode.
   gpustreamCaptureStatus_t capture_status;
   GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
-  bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;
-
-  gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
-  gpuStream_t autotune_stream = stream;
-
-  // An event that synchronizes autotuning stream with a main one.
-  gpuEvent_t autotune_event = nullptr;
-
-  if (is_capturing) {
-    GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
-
-    // Record event after completion of launched kernels on the main stream.
-    GPU_RETURN_IF_ERROR(gpuEventCreate(&autotune_event, 0));
-    GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, stream));
-
-    // Create a new stream to run autotuning and synchronize it with main sream.
-    GPU_RETURN_IF_ERROR(
-        gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
-    GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(autotune_stream, autotune_event));
+  if (capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE) {
+    return absl::FailedPreconditionError(
+        "Can't autotune Triton kernel when the stream is in graph capture "
+        "mode. Autotuning can rely on real data present in input buffers to "
+        "use them in address computation, but in graph capture mode buffers "
+        "can have arbitrary data");
   }
 
   // If an input aliases with an output, it will get overwritten during the
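Stripped of jaxlib's gpu* portability wrappers, the new check boils down to the following standalone sketch. This is illustrative only: it is written against the plain CUDA runtime API and absl::Status, and the helper name is made up; the real code above uses the gpu* aliases so the same source builds for CUDA and ROCm.

```cpp
#include <cuda_runtime_api.h>

#include "absl/status/status.h"

// Hypothetical standalone helper mirroring the check added in this commit:
// refuse to autotune when `stream` is currently being captured into a graph.
absl::Status EnsureNotCapturing(cudaStream_t stream) {
  cudaStreamCaptureStatus capture_status;
  cudaError_t err = cudaStreamIsCapturing(stream, &capture_status);
  if (err != cudaSuccess) {
    return absl::InternalError(cudaGetErrorString(err));
  }
  if (capture_status == cudaStreamCaptureStatusActive) {
    return absl::FailedPreconditionError(
        "Can't autotune Triton kernel when the stream is in graph capture "
        "mode.");
  }
  return absl::OkStatus();
}
```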
@@ -581,8 +566,7 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
       std::vector<uint8_t> input_copy(size);
       GPU_RETURN_IF_ERROR(gpuMemcpyDtoHAsync(
           input_copy.data(),
-          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size,
-          autotune_stream));
+          reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]), size, stream));
       input_copies[input_idx] = std::move(input_copy);
     }
   }
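The aliased-input handling above (saved here, restored in the last hunk below) follows a common save/run/restore pattern around benchmarking. A minimal sketch at the CUDA runtime level, not the jaxlib code itself:

```cpp
#include <cstdint>
#include <vector>

#include <cuda_runtime_api.h>

// Illustrative save/restore of a device buffer that aliases an output.
void SaveRunRestore(void* device_buffer, size_t size, cudaStream_t stream) {
  // Stash the original contents on the host.
  std::vector<uint8_t> host_copy(size);
  cudaMemcpyAsync(host_copy.data(), device_buffer, size,
                  cudaMemcpyDeviceToHost, stream);

  // ... launch the candidate kernels on `stream` here; they may overwrite
  // `device_buffer` because it aliases one of the outputs ...

  // Put the original bytes back and wait for the copy to finish before
  // `host_copy` is destroyed.
  cudaMemcpyAsync(device_buffer, host_copy.data(), size,
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);
}
```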
@@ -592,8 +576,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   // iterations to run for benchmarking.
   float best = std::numeric_limits<float>::infinity();
   for (Config& config : kernel_call.configs_) {
-    JAX_ASSIGN_OR_RETURN(
-        float t, Benchmark(autotune_stream, config.kernel_call, buffers, 1));
+    JAX_ASSIGN_OR_RETURN(float t,
+                         Benchmark(stream, config.kernel_call, buffers, 1));
     LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
     best = std::min(best, t);
   }
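Benchmark here is jaxlib's internal timing helper; its implementation is not shown in this diff. A common way to time such a run on a stream uses CUDA events, roughly like the sketch below (an illustration under that assumption, not the actual helper):

```cpp
#include <cuda_runtime_api.h>

// Illustrative GPU timing loop: records events around `iters` launches of a
// kernel-launching callback and returns the elapsed time in milliseconds.
template <typename LaunchFn>
float TimeOnStream(cudaStream_t stream, int iters, LaunchFn&& launch) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  for (int i = 0; i < iters; ++i) {
    launch(stream);  // Enqueue one kernel launch on `stream`.
  }
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}
```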
@@ -609,16 +593,16 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
   }
 
   best = std::numeric_limits<float>::infinity();
-  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(autotune_stream));
+  JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
   for (Config& config : kernel_call.configs_) {
     if (!config.kernel_call.CanLaunchOnDevice(device)) {
       LOG(WARNING) << "Unable to launch autotune config on device: "
                    << config.description;
       continue;
     }
 
-    JAX_ASSIGN_OR_RETURN(float t, Benchmark(autotune_stream, config.kernel_call,
-                                            buffers, timed_iters));
+    JAX_ASSIGN_OR_RETURN(
+        float t, Benchmark(stream, config.kernel_call, buffers, timed_iters));
     LOG(INFO) << config.description << ", ran " << timed_iters << " iters in "
               << t << " ms";
 
@@ -639,25 +623,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
 
   // Restore aliased inputs to their original values.
   for (auto [input_idx, _, size] : kernel_call.input_output_aliases_) {
-    GPU_RETURN_IF_ERROR(gpuMemcpyHtoDAsync(
-        reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
-        input_copies[input_idx].data(), size, autotune_stream));
+    GPU_RETURN_IF_ERROR(
+        gpuMemcpyHtoDAsync(reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
+                           input_copies[input_idx].data(), size, stream));
   }
 
   // Synchronize stream to ensure copies are complete before the host copy
   // is deleted.
-  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream));
-
-  if (is_capturing) {
-    // Wait on a main stream for completion of autotuning.
-    GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, autotune_stream));
-    GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(stream, autotune_event));
-    GPU_RETURN_IF_ERROR(gpuEventDestroy(autotune_event));
-
-    // Destroy autotuning stream and recover stream capturing mode.
-    GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream));
-    GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
-  }
+  GPU_RETURN_IF_ERROR(gpuStreamSynchronize(stream));
 
   return std::move(kernel_call.configs_[0].kernel_call);
 }
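For readers less familiar with graph capture: a stream is in "capture mode" between cudaStreamBeginCapture and cudaStreamEndCapture, during which launches are recorded into a graph instead of executing immediately, so measurements or data-dependent decisions made there would be based on work that has not run and on buffers whose contents are arbitrary. A minimal CUDA sketch of capture, independent of jax-triton:

```cpp
#include <cuda_runtime_api.h>

// Minimal illustration of stream capture: work enqueued between Begin/End
// is recorded into `graph` rather than executed immediately.
cudaGraph_t CaptureSomeWork(cudaStream_t stream) {
  cudaGraph_t graph = nullptr;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // ... enqueue kernels / memcpys on `stream` here; nothing runs yet ...

  cudaStreamEndCapture(stream, &graph);
  // The caller can instantiate and launch `graph` later, possibly many times.
  return graph;
}
```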
