
Commit 498e81a

superbobry authored and jax authors committed
Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
1 parent 7227080 commit 498e81a

File tree

7 files changed, +19 -238 lines changed
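For orientation, the user-facing API is unchanged: a kernel written against `jax.experimental.pallas` is traced and run the same way, it simply always goes through XLA's Triton compilation path on GPU now. A minimal sketch (assuming jax/jaxlib 0.4.27 and a CUDA-capable GPU), modeled on the `add_one` kernel removed from `pallas_test.py` below:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl


    def add_one_kernel(x_ref, o_ref):
      # Read the input block, add one, and write the result to the output block.
      o_ref[...] = x_ref[...] + 1.0


    @jax.jit
    def add_one(x):
      # pallas_call traces the kernel to Triton IR; XLA compiles it to a GPU
      # binary. There is no JAX_TRITON_COMPILE_VIA_XLA toggle anymore.
      return pl.pallas_call(
          add_one_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      )(x)


    x = jnp.arange(8, dtype=jnp.float32)
    print(add_one(x))  # [1. 2. 3. 4. 5. 6. 7. 8.]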

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.27
 
+* Deprecations & Removals
+  * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
+    lowering pass via Triton Python APIs has been removed and the
+    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
+
+
 ## jaxlib 0.4.27
 
 ## jax 0.4.26 (April 3, 2024)

jax/_src/pallas/triton/BUILD

Lines changed: 0 additions & 2 deletions
@@ -67,7 +67,5 @@ pytype_strict_library(
         "//jax:util",
         "//jax/_src/lib",
         "//jax/_src/pallas",
-        # Users are expected to add a jax_triton dependency to use the legacy
-        # lowering path.
     ],
 )

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 1 addition & 145 deletions
@@ -19,74 +19,17 @@
 
 from __future__ import annotations
 
-import dataclasses
 import io
 from typing import Any
-import zlib
 
 import jax
 from jax import core as jax_core
-from jax._src import config
 from jax._src.interpreters import mlir
 from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.pallas_call import pallas_call_p
 from jax._src.pallas.triton import lowering
-from jax._src import util
-
-
-@dataclasses.dataclass
-class CompilationResult:
-  kernel_name: str
-  ttir: str
-  ptx: str
-  shared_mem_bytes: int
-  compute_capability: int
-  lowering_result: lowering.LoweringResult
-
-
-@util.weakref_lru_cache
-def compile_jaxpr(
-    jaxpr: jax_core.Jaxpr,
-    in_shapes,
-    grid_mapping: pallas_core.GridMapping,
-    name: str,
-    num_warps: int,
-    num_stages: int,
-    debug: bool,
-) -> CompilationResult:
-  from jax_triton.triton_lib import compile_ttir_to_ptx_inplace  # type: ignore
-  import triton.backends.nvidia.compiler as cb  # type: ignore
-
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
-  target = ("cuda", compute_capability)
-  cuda_backend = cb.CUDABackend(target)
-  cuda_options = cuda_backend.parse_options(
-      dict(
-          num_warps=num_warps,
-          num_stages=num_stages,
-          debug=debug,
-      )
-  )
-  lowering_result = lowering.lower_jaxpr_to_triton_module(
-      jaxpr, in_shapes, grid_mapping, name, cuda_options
-  )
-
-  ttir = str(lowering_result.module)
-  ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace(
-      lowering_result.module,
-      cuda_backend,
-      cuda_options,
-      compute_capability,
-  )
-  return CompilationResult(
-      name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
-  )
 
 
 def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
@@ -101,81 +44,6 @@ def avals_to_layouts(avals):
   return [list(reversed(range(aval.ndim))) for aval in avals]
 
 
-def _pallas_call_ptx_lowering(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    name: str,
-    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    debug: bool,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    grid_mapping: pallas_core.GridMapping,
-    triton_params: dict[str, Any],
-    num_warps: int,
-    num_stages: int,
-):
-  compilation_result = compile_jaxpr(
-      jaxpr,
-      (*in_shapes, *out_shapes),
-      grid_mapping,
-      name,
-      num_warps,
-      num_stages,
-      debug=debug,
-  )
-  # Triton returns a tuple for ROCm. We just want file path to be passed
-  if ctx.module_context.platforms[0] == 'rocm':
-    compilation_result.ptx = compilation_result.ptx[1]
-
-  if debug:
-    compilation_result.lowering_result.module.dump()
-
-  kernel = triton_kernel_call_lib.TritonKernel(
-      compilation_result.kernel_name,
-      num_warps,
-      compilation_result.shared_mem_bytes,
-      compilation_result.ptx,
-      compilation_result.ttir,
-      compilation_result.compute_capability,
-      1,
-      1,
-      1,  # TODO(giorgioa): Add support for clustering on H100s on Pallas.
-  )
-
-  grid = normalize_grid(compilation_result.lowering_result.grid)
-
-  kernel_params = []
-  for _ in range(len(in_shapes) + len(out_shapes)):
-    kernel_params.append(
-        triton_kernel_call_lib.create_array_parameter(
-            0,  # bytes to zero  # TODO(cjfj): Expose through user API.
-            16,  # divisible by 16
-        )
-    )
-
-  kernel_call = triton_kernel_call_lib.TritonKernelCall(
-      kernel, grid[0], grid[1], grid[2], kernel_params
-  )
-
-  out_types = [
-      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
-      for shape in out_shapes
-  ]
-
-  serialized_metadata = triton_params.get("serialized_metadata", b"")
-  kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
-  return mlir.custom_call(
-      call_target_name="triton_kernel_call",
-      result_types=out_types,
-      operands=in_nodes,
-      backend_config=zlib.compress(kernel_call_proto),
-      operand_layouts=avals_to_layouts(ctx.avals_in),
-      result_layouts=avals_to_layouts(ctx.avals_out),
-      operand_output_aliases=dict(input_output_aliases),
-  ).results
-
-
 def _pallas_call_ttir_lowering(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
@@ -243,13 +111,6 @@ def _pallas_call_ttir_lowering(
   ).results
 
 
-_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool(
-    "jax_triton_compile_via_xla",
-    default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True),
-    help="If True, Pallas delegates Triton kernel compilation to XLA.",
-)
-
-
 def pallas_call_lowering(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
@@ -298,12 +159,7 @@ def pallas_call_lowering(
     print(jaxpr)
     print(grid_mapping)
 
-  if _TRITON_COMPILE_VIA_XLA.value:
-    lowering_fn = _pallas_call_ttir_lowering
-  else:
-    lowering_fn = _pallas_call_ptx_lowering
-
-  return lowering_fn(
+  return _pallas_call_ttir_lowering(
       ctx,
       *in_nodes,
      jaxpr=jaxpr,
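In short: the removed `_pallas_call_ptx_lowering` compiled the Triton IR to PTX eagerly inside JAX's lowering (via `jax_triton` and Triton's `CUDABackend`) and embedded the result in a `triton_kernel_call` custom call, while the retained `_pallas_call_ttir_lowering` hands the Triton IR module to XLA, which drives kernel compilation itself. The dispatch in `pallas_call_lowering` therefore collapses to a single path; roughly (a sketch, not the literal file contents):

    # Sketch only: after this change pallas_call_lowering always takes the
    # TTIR/XLA path; the _TRITON_COMPILE_VIA_XLA flag and the PTX path are gone.
    def pallas_call_lowering(ctx, *in_nodes, jaxpr, **params):
      return _pallas_call_ttir_lowering(ctx, *in_nodes, jaxpr=jaxpr, **params)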

jax/experimental/pallas/gpu.py

Lines changed: 3 additions & 7 deletions
@@ -13,10 +13,6 @@
 # limitations under the License.
 
 """Contains Triton specific Pallas functions."""
-try:
-  from jax._src.pallas import triton
-  get_compute_capability = triton.get_compute_capability
-  del triton
-except ImportError as e:
-  raise ImportError("Cannot import Pallas Triton backend. "
-                    "Make sure you've installed jax-triton.") from e
+from jax._src.pallas import triton
+get_compute_capability = triton.get_compute_capability
+del triton
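With jax-triton no longer required, `get_compute_capability` is imported unconditionally and the ImportError fallback goes away. A small usage sketch; the device-index argument and integer result are inferred from the removed `compile_jaxpr` above (which called `triton_kernel_call_lib.get_compute_capability(device)`), so treat the exact signature as an assumption:

    from jax.experimental.pallas import gpu as plgpu

    # Query the CUDA compute capability of device 0 (assumed signature/return).
    cc = plgpu.get_compute_capability(0)
    print(cc)  # e.g. 80 on an A100, 90 on an H100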

tests/pallas/BUILD

Lines changed: 0 additions & 38 deletions
@@ -56,9 +56,6 @@ jax_test(
         "gpu_a100_x32",
         "gpu_h100_x32",
     ],
-    env = {
-        "JAX_TRITON_COMPILE_VIA_XLA": "0",
-    },
     shard_count = 4,
     deps = [
         "//jax:pallas_gpu",
@@ -101,41 +98,6 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
-    name = "pallas_via_xla_test",
-    srcs = [
-        "pallas_test.py",
-    ],
-    backend_tags = {
-        "gpu": ["noasan"],  # https://github.com/openai/triton/issues/2918
-    },
-    config_tags_overrides = {
-        "gpu_a100_x32": {
-            "ondemand": False,  # Include in presubmit.
-        },
-    },
-    disable_backends = [
-        "cpu",
-        "tpu",
-    ],
-    disable_configs = [
-        "gpu",
-        "gpu_x32",
-        "gpu_p100",
-        "gpu_p100_x32",
-        "gpu_a100",
-        "gpu_h100",
-    ],
-    enable_configs = [
-        "gpu_a100_x32",
-        "gpu_h100_x32",
-    ],
-    shard_count = 4,
-    deps = [
-        "//jax:pallas_gpu",
-    ] + py_deps("absl/testing") + py_deps("numpy"),
-)
-
 jax_test(
     name = "ops_test",
     srcs = [

tests/pallas/ops_test.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 from jax.experimental import pallas as pl
 try:
   from jax.experimental.pallas import gpu as plgpu
-except (ModuleNotFoundError, ImportError):
+except ImportError:
   plgpu = None
 import jax.numpy as jnp
 import numpy as np
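The narrowed clause loses nothing: `ModuleNotFoundError` is a subclass of `ImportError`, so catching `ImportError` alone still covers a missing module. A quick check:

    # ModuleNotFoundError (Python 3.6+) subclasses ImportError, so the single
    # except clause above handles a missing module and any other import failure.
    assert issubclass(ModuleNotFoundError, ImportError)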

tests/pallas/pallas_test.py

Lines changed: 8 additions & 45 deletions
@@ -33,25 +33,21 @@
 from jax._src.lax.control_flow.for_loop import for_loop
 from jax._src.lib import version as jaxlib_version
 from jax._src.pallas.pallas_call import _trace_to_jaxpr
+if jaxlib_version >= (0, 4, 24):
+  from jax._src.pallas.triton.lowering import LoweringError
+else:
+  LoweringError = Exception
 from jax.interpreters import partial_eval as pe
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+try:
+  from jax.experimental.pallas import gpu as plgpu
+except ImportError:
+  plgpu = None
 from jax.experimental.pallas.ops import attention
 from jax.experimental.pallas.ops import layer_norm
 from jax.experimental.pallas.ops import rms_norm
 from jax.experimental.pallas.ops import softmax
-try:
-  from jax._src.pallas.triton.lowering import LoweringError
-  from jax._src.pallas.triton.pallas_call_registration import (
-      compile_jaxpr,
-      _TRITON_COMPILE_VIA_XLA,
-  )
-  from jax.experimental.pallas import gpu as plgpu
-except ModuleNotFoundError:
-  LoweringError = Exception
-  compile_jaxpr = None
-  _TRITON_COMPILE_VIA_XLA = None
-  plgpu = None
 import numpy as np
 
 
@@ -143,17 +139,7 @@ def setUp(self):
         not self.check_gpu_capability_at_least(80)):
      self.skipTest("Only works on GPUs with capability >= sm80")
 
-    try:
-      import triton  # noqa: F401
-    except ImportError:
-      if (
-          _TRITON_COMPILE_VIA_XLA is not None
-          and not _TRITON_COMPILE_VIA_XLA.value
-      ):
-        self.skipTest("Triton is not installed.")
     super().setUp()
-    if compile_jaxpr:
-      compile_jaxpr.cache_clear()
     _trace_to_jaxpr.cache_clear()
 
   def pallas_call(self, *args, **kwargs):
@@ -761,29 +747,6 @@ def f(x):
     self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)
 
-  def test_pallas_compilation_cache(self):
-    if not compile_jaxpr:
-      self.skipTest("No Triton GPU.")
-    if self.INTERPRET:
-      raise unittest.SkipTest("No Triton compilation in interpreter mode.")
-    if _TRITON_COMPILE_VIA_XLA.value:
-      raise unittest.SkipTest("Triton is compiled via XLA.")
-
-    @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
-        grid=1)
-    def add_one(x_ref, o_ref):
-      o_ref[()] = x_ref[()] + 1.
-
-    @jax.jit
-    def f(x):
-      return add_one(add_one(x))
-
-    x = jnp.array(0., dtype=jnp.float32)
-    self.assertEqual(f(x), 2.)
-    num_misses = compile_jaxpr.cache_info().misses
-    self.assertEqual(num_misses, 1)
-
   @parameterized.parameters(*[
       (0, 0, 1),
       (0, 1, 1),
