14
14
15
15
# NVIDIA CUDA kernels
16
16
17
- load ("@rules_python//python:defs.bzl" , "py_library" )
18
17
load (
19
18
"//jaxlib:jax.bzl" ,
20
19
"cuda_library" ,
21
20
"if_cuda_is_configured" ,
22
21
"pybind_extension" ,
23
22
)
23
+ load ("@rules_python//python:defs.bzl" , "py_library" )
24
24
25
25
licenses (["notice" ])
26
26
@@ -56,9 +56,9 @@ cc_library(
56
56
features = ["-use_header_modules" ],
57
57
deps = [
58
58
":cuda_vendor" ,
59
- "@tsl/ /tsl/cuda:cupti" ,
60
- "@tsl/ /tsl/cuda:cusolver" ,
61
- "@tsl/ /tsl/cuda:cusparse" ,
59
+ "@xla//xla /tsl/cuda:cupti" ,
60
+ "@xla//xla /tsl/cuda:cusolver" ,
61
+ "@xla//xla /tsl/cuda:cusparse" ,
62
62
"@com_google_absl//absl/base:core_headers" ,
63
63
"@com_google_absl//absl/log:check" ,
64
64
"@com_google_absl//absl/memory" ,
@@ -81,8 +81,8 @@ cc_library(
81
81
"//jaxlib:handle_pool" ,
82
82
"//jaxlib:kernel_helpers" ,
83
83
"@xla//xla/service:custom_call_status" ,
84
- "@tsl/ /tsl/cuda:cublas" ,
85
- "@tsl/ /tsl/cuda:cudart" ,
84
+ "@xla//xla /tsl/cuda:cublas" ,
85
+ "@xla//xla /tsl/cuda:cudart" ,
86
86
"@com_google_absl//absl/algorithm:container" ,
87
87
"@com_google_absl//absl/base" ,
88
88
"@com_google_absl//absl/base:core_headers" ,
@@ -117,7 +117,7 @@ pybind_extension(
117
117
":cublas_kernels" ,
118
118
":cuda_vendor" ,
119
119
"//jaxlib:kernel_nanobind_helpers" ,
120
- "@tsl/ /tsl/cuda:cublas" ,
120
+ "@xla//xla /tsl/cuda:cublas" ,
121
121
"@tsl//tsl/python/lib/core:numpy" ,
122
122
"@com_google_absl//absl/container:flat_hash_map" ,
123
123
"@com_google_absl//absl/strings:str_format" ,
@@ -135,8 +135,8 @@ cc_library(
135
135
"//jaxlib:handle_pool" ,
136
136
"//jaxlib:kernel_helpers" ,
137
137
"@xla//xla/service:custom_call_status" ,
138
- "@tsl/ /tsl/cuda:cudart" ,
139
- "@tsl/ /tsl/cuda:cudnn" ,
138
+ "@xla//xla /tsl/cuda:cudart" ,
139
+ "@xla//xla /tsl/cuda:cudnn" ,
140
140
"@com_google_absl//absl/status" ,
141
141
"@com_google_absl//absl/status:statusor" ,
142
142
"@com_google_absl//absl/strings:str_format" ,
@@ -174,7 +174,7 @@ cc_library(
174
174
"//jaxlib:handle_pool" ,
175
175
"//jaxlib:kernel_helpers" ,
176
176
"@xla//xla/service:custom_call_status" ,
177
- "@tsl/ /tsl/cuda:cusolver" ,
177
+ "@xla//xla /tsl/cuda:cusolver" ,
178
178
"@com_google_absl//absl/status" ,
179
179
"@com_google_absl//absl/status:statusor" ,
180
180
"@com_google_absl//absl/synchronization" ,
@@ -203,8 +203,8 @@ pybind_extension(
203
203
":cuda_vendor" ,
204
204
":cusolver_kernels" ,
205
205
"//jaxlib:kernel_nanobind_helpers" ,
206
- "@tsl/ /tsl/cuda:cudart" ,
207
- "@tsl/ /tsl/cuda:cusolver" ,
206
+ "@xla//xla /tsl/cuda:cudart" ,
207
+ "@xla//xla /tsl/cuda:cusolver" ,
208
208
"@tsl//tsl/python/lib/core:numpy" ,
209
209
"@com_google_absl//absl/container:flat_hash_map" ,
210
210
"@com_google_absl//absl/strings:str_format" ,
@@ -223,8 +223,8 @@ cc_library(
223
223
"//jaxlib:handle_pool" ,
224
224
"//jaxlib:kernel_helpers" ,
225
225
"@xla//xla/service:custom_call_status" ,
226
- "@tsl/ /tsl/cuda:cudart" ,
227
- "@tsl/ /tsl/cuda:cusparse" ,
226
+ "@xla//xla /tsl/cuda:cudart" ,
227
+ "@xla//xla /tsl/cuda:cusparse" ,
228
228
"@com_google_absl//absl/status" ,
229
229
"@com_google_absl//absl/status:statusor" ,
230
230
"@com_google_absl//absl/synchronization" ,
@@ -253,8 +253,8 @@ pybind_extension(
253
253
":cuda_vendor" ,
254
254
":cusparse_kernels" ,
255
255
"//jaxlib:kernel_nanobind_helpers" ,
256
- "@tsl/ /tsl/cuda:cudart" ,
257
- "@tsl/ /tsl/cuda:cusparse" ,
256
+ "@xla//xla /tsl/cuda:cudart" ,
257
+ "@xla//xla /tsl/cuda:cusparse" ,
258
258
"@tsl//tsl/python/lib/core:numpy" ,
259
259
"@com_google_absl//absl/algorithm:container" ,
260
260
"@com_google_absl//absl/base" ,
@@ -316,7 +316,7 @@ pybind_extension(
316
316
":cuda_lu_pivot_kernels_impl" ,
317
317
":cuda_vendor" ,
318
318
"//jaxlib:kernel_nanobind_helpers" ,
319
- "@tsl/ /tsl/cuda:cudart" ,
319
+ "@xla//xla /tsl/cuda:cudart" ,
320
320
"@local_config_cuda//cuda:cuda_headers" ,
321
321
"@nanobind" ,
322
322
],
@@ -366,7 +366,7 @@ pybind_extension(
366
366
":cuda_gpu_kernel_helpers" ,
367
367
":cuda_prng_kernels" ,
368
368
"//jaxlib:kernel_nanobind_helpers" ,
369
- "@tsl/ /tsl/cuda:cudart" ,
369
+ "@xla//xla /tsl/cuda:cudart" ,
370
370
"@local_config_cuda//cuda:cuda_headers" ,
371
371
"@nanobind" ,
372
372
],
@@ -400,7 +400,7 @@ cc_library(
400
400
"//jaxlib/gpu:triton_cc_proto" ,
401
401
"@xla//xla/service:custom_call_status" ,
402
402
"@xla//xla/stream_executor/gpu:asm_compiler" ,
403
- "@tsl/ /tsl/cuda:cudart" ,
403
+ "@xla//xla /tsl/cuda:cudart" ,
404
404
"@tsl//tsl/platform:env" ,
405
405
"@com_google_absl//absl/base:core_headers" ,
406
406
"@com_google_absl//absl/cleanup" ,
@@ -472,13 +472,13 @@ cc_library(
472
472
":cuda_vendor" ,
473
473
"//jaxlib:absl_status_casters" ,
474
474
"//jaxlib:kernel_nanobind_helpers" ,
475
- "@tsl/ /tsl/cuda:cublas" ,
476
- "@tsl/ /tsl/cuda:cudart" ,
477
- "@tsl/ /tsl/cuda:cudnn" ,
478
- "@tsl/ /tsl/cuda:cufft" ,
479
- "@tsl/ /tsl/cuda:cupti" ,
480
- "@tsl/ /tsl/cuda:cusolver" ,
481
- "@tsl/ /tsl/cuda:cusparse" ,
475
+ "@xla//xla /tsl/cuda:cublas" ,
476
+ "@xla//xla /tsl/cuda:cudart" ,
477
+ "@xla//xla /tsl/cuda:cudnn" ,
478
+ "@xla//xla /tsl/cuda:cufft" ,
479
+ "@xla//xla /tsl/cuda:cupti" ,
480
+ "@xla//xla /tsl/cuda:cusolver" ,
481
+ "@xla//xla /tsl/cuda:cusparse" ,
482
482
],
483
483
)
484
484
@@ -509,13 +509,13 @@ pybind_extension(
509
509
":versions_helpers" ,
510
510
"//jaxlib:absl_status_casters" ,
511
511
"//jaxlib:kernel_nanobind_helpers" ,
512
- "@tsl/ /tsl/cuda:cublas" ,
513
- "@tsl/ /tsl/cuda:cudart" ,
514
- "@tsl/ /tsl/cuda:cudnn" ,
515
- "@tsl/ /tsl/cuda:cufft" ,
516
- "@tsl/ /tsl/cuda:cupti" ,
517
- "@tsl/ /tsl/cuda:cusolver" ,
518
- "@tsl/ /tsl/cuda:cusparse" ,
512
+ "@xla//xla /tsl/cuda:cublas" ,
513
+ "@xla//xla /tsl/cuda:cudart" ,
514
+ "@xla//xla /tsl/cuda:cudnn" ,
515
+ "@xla//xla /tsl/cuda:cufft" ,
516
+ "@xla//xla /tsl/cuda:cupti" ,
517
+ "@xla//xla /tsl/cuda:cusolver" ,
518
+ "@xla//xla /tsl/cuda:cusparse" ,
519
519
"@com_google_absl//absl/status:statusor" ,
520
520
"@nanobind" ,
521
521
],
0 commit comments