Skip to content

Commit 9809aa1

Browse files
beckerhejax authors
authored andcommitted
Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids: - a forward declaration of `GpuContext` - the `:asm_compiler_header` header only target The moved code is unchanged - I just move it from one file to another and fix up includes and dependencies. Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change. PiperOrigin-RevId: 623285804
1 parent a205c91 commit 9809aa1

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

jaxlib/cuda/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ cc_library(
399399
":triton_utils",
400400
"//jaxlib/gpu:triton_cc_proto",
401401
"@xla//xla/service:custom_call_status",
402-
"@xla//xla/stream_executor/gpu:asm_compiler",
402+
"@xla//xla/stream_executor/cuda:cuda_asm_compiler",
403403
"@xla//xla/tsl/cuda:cudart",
404404
"@tsl//tsl/platform:env",
405405
"@com_google_absl//absl/base:core_headers",

jaxlib/gpu/triton_kernels.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@
3030
#include "jaxlib/gpu/triton_utils.h"
3131
#include "jaxlib/gpu/vendor.h"
3232
#include "xla/service/custom_call_status.h"
33-
#include "xla/stream_executor/gpu/asm_compiler.h"
3433
#include "tsl/platform/env.h"
3534

35+
#ifdef JAX_GPU_CUDA
36+
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
37+
#endif
38+
3639
#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
3740

3841
namespace jax::JAX_GPU_NAMESPACE {

0 commit comments

Comments
 (0)