Skip to content

Commit 90e9e47

Browse files
marvin-kimjax authors
authored andcommitted
[Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
For configs that cannot be launched, we should not launch them via benchmark. PiperOrigin-RevId: 626153377
1 parent de7d3b6 commit 90e9e47

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

jaxlib/cuda/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ cc_library(
406406
"@com_google_absl//absl/base:core_headers",
407407
"@com_google_absl//absl/cleanup",
408408
"@com_google_absl//absl/container:flat_hash_map",
409+
"@com_google_absl//absl/container:flat_hash_set",
409410
"@com_google_absl//absl/log",
410411
"@com_google_absl//absl/log:check",
411412
"@com_google_absl//absl/status",

jaxlib/gpu/triton_kernels.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "absl/base/thread_annotations.h"
2020
#include "absl/cleanup/cleanup.h"
2121
#include "absl/container/flat_hash_map.h"
22+
#include "absl/container/flat_hash_set.h"
2223
#include "absl/log/check.h"
2324
#include "absl/log/log.h"
2425
#include "absl/status/status.h"
@@ -587,7 +588,13 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
587588
// First run a single iteration of each to config to determine how many
588589
// iterations to run for benchmarking.
589590
float best = std::numeric_limits<float>::infinity();
591+
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
592+
absl::flat_hash_set<Config*> configs_to_skip;
590593
for (Config& config : kernel_call.configs_) {
594+
if (!config.kernel_call.CanLaunchOnDevice(device)) {
595+
configs_to_skip.insert(&config);
596+
continue;
597+
}
591598
JAX_ASSIGN_OR_RETURN(float t,
592599
Benchmark(stream, config.kernel_call, buffers, 1));
593600
LOG(INFO) << config.description << ", ran 1 iter in " << t << " ms";
@@ -605,9 +612,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
605612
}
606613

607614
best = std::numeric_limits<float>::infinity();
608-
JAX_ASSIGN_OR_RETURN(gpuDevice_t device, GetStreamDevice(stream));
609615
for (Config& config : kernel_call.configs_) {
610-
if (!config.kernel_call.CanLaunchOnDevice(device)) {
616+
if (configs_to_skip.contains(&config)) {
611617
LOG(WARNING) << "Unable to launch autotune config on device: "
612618
<< config.description;
613619
continue;

0 commit comments

Comments
 (0)