
Commit f74f4ed

superbobry authored and jax authors committed

Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable to :pallas_test.

PiperOrigin-RevId: 621299158
1 parent a54eb81 commit f74f4ed

6 files changed: +30 additions, −16 deletions

jax/BUILD

Lines changed: 2 additions & 0 deletions
@@ -104,13 +104,15 @@ package_group(
     name = "pallas_gpu_users",
     packages = [
         "//...",
+        "//learning/brain/research/jax",
     ] + pallas_gpu_internal_users,
 )
 
 package_group(
     name = "pallas_tpu_users",
     packages = [
         "//...",
+        "//learning/brain/research/jax",
     ] + pallas_tpu_internal_users,
 )

jax/_src/pallas/pallas_call.py

Lines changed: 16 additions & 10 deletions
@@ -523,17 +523,23 @@ def _pallas_call_lowering(
   if platform == "cpu":
     raise ValueError("Only interpret mode is supported on CPU backend.")
   elif platform == "cuda" or platform == "rocm":
-    from jax._src.pallas.triton import pallas_call_registration  # type: ignore
-
-    return pallas_call_registration.pallas_call_lowering(
-        ctx, *in_nodes, interpret=interpret, **params
-    )
+    try:
+      from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+    except ImportError:
+      pass
+    else:
+      return pallas_call_registration.pallas_call_lowering(
+          ctx, *in_nodes, interpret=interpret, **params
+      )
   elif platform == "tpu":
-    from jax._src.pallas.mosaic import pallas_call_registration  # type: ignore
-
-    return pallas_call_registration.pallas_call_tpu_lowering_rule(
-        ctx, *in_nodes, interpret=interpret, **params
-    )
+    try:
+      from jax._src.pallas.mosaic import pallas_call_registration  # type: ignore
+    except ImportError:
+      pass
+    else:
+      return pallas_call_registration.pallas_call_tpu_lowering_rule(
+          ctx, *in_nodes, interpret=interpret, **params
+      )
 
   raise _unsupported_lowering_error(platform)

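The hunk above turns hard backend imports into a graceful fall-through: when a registration module's optional dependency is missing, lowering skips ahead to the _unsupported_lowering_error call instead of failing at import time. Below is a minimal standalone sketch of that try/except/else pattern; optional_backend is a hypothetical module standing in for the real pallas_call_registration imports.

# Minimal sketch of the optional-import pattern adopted above.
# "optional_backend" is hypothetical; the real code imports
# jax._src.pallas.triton (or .mosaic) pallas_call_registration.
def lower(platform: str):
    if platform in ("cuda", "rocm"):
        try:
            import optional_backend  # absent in minimal installs
        except ImportError:
            pass  # fall through to the generic error below
        else:
            return optional_backend.lower(platform)
    raise NotImplementedError(f"no lowering rule registered for {platform!r}")
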
jax/_src/pallas/triton/BUILD

Lines changed: 3 additions & 1 deletion
@@ -67,5 +67,7 @@ pytype_strict_library(
         "//jax:util",
         "//jax/_src/lib",
         "//jax/_src/pallas",
-    ] + py_deps("jax_triton"),
+        # Users are expected to add a jax_triton dependency to use the legacy
+        # lowering path.
+    ],
 )

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ def compile_jaxpr(
     num_stages: int,
     debug: bool,
 ) -> CompilationResult:
-  from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
-  import triton.backends.nvidia.compiler as cb
+  from jax_triton.triton_lib import compile_ttir_to_ptx_inplace  # type: ignore
+  import triton.backends.nvidia.compiler as cb  # type: ignore
 
   # TODO(sharadmv): handle multiple devices, right now we assume device 0
   # which is fine when we have multiple of the same GPU but this won't work in

jax/experimental/jax2tf/tests/primitives_test.py

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,8 @@ def test_primitive_coverage(self):
       if p.name == "debug_callback":
         # TODO(sharadmv,necula): enable debug callbacks in TF
         continue
+      if p.name == "pallas_call":
+        continue
       if p.name in tf_not_yet_impl:
         self.assertNotIn(
             p, tf_impl)  # Should not be in both tf_impl and tf_not_yet_impl

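The new skip keeps the jax2tf primitive-coverage test green: pallas_call is a registered JAX primitive with no TF lowering, so it is exempted by name before the coverage assertion runs. A minimal sketch of that pattern (assumed shape, not the actual test code):

# Sketch of a name-based skip in a primitive-coverage check (assumed
# shape; the real logic lives in jax2tf's test_primitive_coverage).
KNOWN_SKIPS = {"debug_callback", "pallas_call"}  # intentionally not lowered

def check_coverage(primitive_names, tf_impl):
    for name in primitive_names:
        if name in KNOWN_SKIPS:
            continue
        assert name in tf_impl, f"missing TF lowering for {name}"
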
tests/pallas/BUILD

Lines changed: 5 additions & 3 deletions
@@ -56,10 +56,13 @@ jax_test(
         "gpu_a100_x32",
         "gpu_h100_x32",
     ],
+    env = {
+        "JAX_TRITON_COMPILE_VIA_XLA": "0",
+    },
     shard_count = 4,
     deps = [
         "//jax:pallas_gpu",
-    ] + py_deps("absl/testing") + py_deps("numpy"),
+    ] + py_deps("absl/testing") + py_deps("jax_triton") + py_deps("numpy"),
 )
 
 jax_test(
@@ -162,8 +165,7 @@ jax_test(
         "gpu_h100_x32",
     ],
     deps = [
-        "//jax:pallas_gpu",
-        "//jax:pallas_tpu",
+        "//jax:pallas",
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )

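The env entry restores the JAX_TRITON_COMPILE_VIA_XLA toggle mentioned in the commit message, pinning :pallas_test to the legacy (non-XLA) compilation path, and jax_triton returns to deps because that path requires it. As a rough illustration only (an assumption about how such a toggle is read, not jax_triton's actual source):

import os

# Assumed reading of the toggle: the "0" set in the env dict above
# selects the legacy path; a truthy value would route Triton kernel
# compilation through XLA instead. The default shown here is a guess.
COMPILE_VIA_XLA = os.environ.get(
    "JAX_TRITON_COMPILE_VIA_XLA", "0"
).lower() in ("1", "true")
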