19
19
#include " absl/base/thread_annotations.h"
20
20
#include " absl/cleanup/cleanup.h"
21
21
#include " absl/container/flat_hash_map.h"
22
+ #include " absl/container/flat_hash_set.h"
22
23
#include " absl/log/check.h"
23
24
#include " absl/log/log.h"
24
25
#include " absl/status/status.h"
@@ -587,7 +588,13 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
587
588
// First run a single iteration of each to config to determine how many
588
589
// iterations to run for benchmarking.
589
590
float best = std::numeric_limits<float >::infinity ();
591
+ JAX_ASSIGN_OR_RETURN (gpuDevice_t device, GetStreamDevice (stream));
592
+ absl::flat_hash_set<Config*> configs_to_skip;
590
593
for (Config& config : kernel_call.configs_ ) {
594
+ if (!config.kernel_call .CanLaunchOnDevice (device)) {
595
+ configs_to_skip.insert (&config);
596
+ continue ;
597
+ }
591
598
JAX_ASSIGN_OR_RETURN (float t,
592
599
Benchmark (stream, config.kernel_call , buffers, 1 ));
593
600
LOG (INFO) << config.description << " , ran 1 iter in " << t << " ms" ;
@@ -605,9 +612,8 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
605
612
}
606
613
607
614
best = std::numeric_limits<float >::infinity ();
608
- JAX_ASSIGN_OR_RETURN (gpuDevice_t device, GetStreamDevice (stream));
609
615
for (Config& config : kernel_call.configs_ ) {
610
- if (!config. kernel_call . CanLaunchOnDevice (device )) {
616
+ if (configs_to_skip. contains (&config )) {
611
617
LOG (WARNING) << " Unable to launch autotune config on device: "
612
618
<< config.description ;
613
619
continue ;
0 commit comments