
Commit 3a09404

Author: jax authors (committed)
Merge pull request #20586 from superbobry:jaxlib
PiperOrigin-RevId: 624941598
2 parents: 78c056f + 754fab9

13 files changed (+25, -146 lines)


CHANGELOG.md

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ Remember to align the itemized text with the first line of an item within a list
     is deprecated; empty inputs to softmax are now supported without setting this.
   * In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
     now leads to an error rather than a warning.
-
+  * The minimum jaxlib version is now 0.4.23.

 ## jaxlib 0.4.27

@@ -156,7 +156,7 @@ Remember to align the itemized text with the first line of an item within a list
     cannot interact, e.g., in arithmetic operations.
     Scopes are introduced by {func}`jax.experimental.jax2tf.convert`,
     {func}`jax.experimental.export.symbolic_shape`, {func}`jax.experimental.export.symbolic_args_specs`.
-    The scope of a symbolic expression `e` can be read with `e.scope` and passed
+    The scope of a symbolic expression `e` can be read with `e.scope` and passed
     into the above functions to direct them to construct symbolic expressions in
     a given scope.
     See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
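Note on the first changelog entry above: a minimal sketch (not part of this commit) of what passing an out-of-range `static_argnums` now does. The concrete exception type is an assumption, so it is caught broadly here.

import jax

def f(x):
  return x + 1

# `static_argnums=(1,)` names a positional argument that `f` does not have,
# so jax.jit now reports an error instead of a warning. The exact exception
# class is assumed, hence the broad except.
try:
  jax.jit(f, static_argnums=(1,))(1.0)
except Exception as e:
  print(type(e).__name__, e)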

jax/_src/dispatch.py

Lines changed: 5 additions & 10 deletions
@@ -44,7 +44,6 @@
 from jax._src.interpreters import pxla
 from jax._src import lib
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.monitoring import record_event_duration_secs
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
@@ -82,15 +81,11 @@ def apply_primitive(prim, *args, **params):
   fun = xla_primitive_callable(prim, **params)
   # TODO(yashkatariya): Investigate adding is_primitive to jit and never
   # triggering the disable jit path instead of messing around with it here.
-  if xla_extension_version >= 218:
-    prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
-    try:
-      outs = fun(*args)
-    finally:
-      lib.jax_jit.swap_thread_local_state_disable_jit(prev)
-  else:
-    with config.disable_jit(False):
-      outs = fun(*args)
+  prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
+  try:
+    outs = fun(*args)
+  finally:
+    lib.jax_jit.swap_thread_local_state_disable_jit(prev)
   return outs


 @util.cache()
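The rewritten body of `apply_primitive` above is a plain save/restore of a thread-local flag: swap in the desired value, run the callable, and restore the previous value even if it raises. A self-contained sketch of the same pattern, using a hypothetical `swap_disable_jit` helper rather than the jaxlib API:

import threading

_state = threading.local()

def swap_disable_jit(value):
  # Set the thread-local flag and return whatever value it held before.
  prev = getattr(_state, "disable_jit", False)
  _state.disable_jit = value
  return prev

def run_with_jit_enabled(fn, *args):
  # Mirrors the pattern in apply_primitive: force the flag off for the call,
  # then restore the previous value in a finally block.
  prev = swap_disable_jit(False)
  try:
    return fn(*args)
  finally:
    swap_disable_jit(prev)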

jax/_src/interpreters/mlir.py

Lines changed: 1 addition & 2 deletions
@@ -49,7 +49,6 @@
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import dialects
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
@@ -911,7 +910,7 @@ def lower_jaxpr_to_module(
           "In multi-platform lowering either all or no lowering platforms "
           f"should support donation. Lowering for {platforms} of which "
           f"only {platforms_with_donation} support donation")
-    if num_partitions > 1 and xla_extension_version >= 220 and (
+    if num_partitions > 1 and (
         result_shardings is None or all(s is None for s in result_shardings)):
       xla_donated_args = donated_args
     if xla_donated_args is None:
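The condition above now keys only on `num_partitions > 1` when deciding whether donated inputs are handed straight to XLA. For background, a hedged user-level sketch of buffer donation with standard `jax.jit` options (not code from this commit); whether the donation is actually honored depends on the backend:

import functools
import jax
import jax.numpy as jnp

# donate_argnums marks the first argument's buffer as donatable, so the
# compiled computation may reuse it for the output instead of allocating.
@functools.partial(jax.jit, donate_argnums=(0,))
def accumulate(acc, x):
  return acc + x

acc = jnp.zeros((1024,))
acc = accumulate(acc, jnp.ones((1024,)))  # the old `acc` buffer may be reused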

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 6 deletions
@@ -2918,12 +2918,8 @@ def from_hlo(name: str,
       in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
           xla_executable.local_devices(), len(in_shardings), len(out_shardings))

-    if xla_extension_version >= 217:
-      in_layouts, out_layouts = _get_layouts_from_executable(
-          xla_executable, in_layouts, out_layouts, len(ordered_effects))
-    else:
-      assert all(i is None for i in in_layouts)
-      assert all(o is None for o in out_layouts)
+    in_layouts, out_layouts = _get_layouts_from_executable(
+        xla_executable, in_layouts, out_layouts, len(ordered_effects))

     out_shardings = maybe_recover_user_shardings(
         in_shardings, out_shardings, global_in_avals, global_out_avals)

jax/_src/lax/fft.py

Lines changed: 0 additions & 77 deletions
@@ -30,8 +30,6 @@
 from jax._src.interpreters import mlir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
-from jax._src.lib import ducc_fft
 from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact

 __all__ = [
@@ -122,76 +120,6 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
   ]


-def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
-  x_aval, = ctx.avals_in
-
-  in_shape = x_aval.shape
-  dtype = x_aval.dtype
-  out_aval, = ctx.avals_out
-  out_shape = out_aval.shape
-
-  forward = fft_type in (xla_client.FftType.FFT, xla_client.FftType.RFFT)
-  ndims = len(in_shape)
-  assert len(fft_lengths) >= 1
-  assert len(fft_lengths) <= ndims, (fft_lengths, ndims)
-  assert len(in_shape) == len(out_shape) == ndims
-
-  # PocketFft does not allow size 0 dimensions.
-  if 0 in in_shape or 0 in out_shape:
-    if fft_type == xla_client.FftType.RFFT:
-      assert dtype in (np.float32, np.float64), dtype
-      out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
-
-    elif fft_type == xla_client.FftType.IRFFT:
-      assert np.issubdtype(dtype, np.complexfloating), dtype
-      out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
-
-    else:
-      assert np.issubdtype(dtype, np.complexfloating), dtype
-      out_dtype = dtype
-
-    zero = mlir.ir_constant(np.array(0, dtype=out_dtype))
-    return [
-        mlir.broadcast_in_dim(ctx, zero, out_aval, broadcast_dimensions=[])]
-
-  strides_in = []
-  stride = 1
-  for d in reversed(in_shape):
-    strides_in.append(stride)
-    stride *= d
-  strides_in = mlir.shape_tensor(
-      mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_in))))
-
-  strides_out = []
-  stride = 1
-  for d in reversed(out_shape):
-    strides_out.append(stride)
-    stride *= d
-  strides_out = mlir.shape_tensor(
-      mlir.eval_dynamic_shape(ctx, tuple(reversed(strides_out))))
-
-  # scale = 1. if forward else (1. / np.prod(fft_lengths)) as a f64[1] tensor
-  double_type = mlir.ir.RankedTensorType.get((), mlir.ir.F64Type.get())
-  size_fft_length_prod = np.prod(fft_lengths) if fft_lengths else 1
-  size_fft_lengths, = mlir.eval_dynamic_shape_as_vals(ctx, (size_fft_length_prod,))
-  size_fft_lengths = hlo.ConvertOp(double_type, size_fft_lengths)
-  one = mlir.ir_constant(np.float64(1.))
-  scale = one if forward else hlo.DivOp(one, size_fft_lengths)
-  scale = hlo.ReshapeOp(
-      mlir.ir.RankedTensorType.get((1,), mlir.ir.F64Type.get()),
-      scale).result
-
-  in_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, in_shape))
-  out_shape = mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, out_shape))
-  in_shape = in_shape if fft_type != xla_client.FftType.IRFFT else out_shape
-
-  result_type = mlir.aval_to_ir_type(out_aval)
-  return [ducc_fft.dynamic_ducc_fft_hlo(
-      result_type, x,
-      input_dtype=x_aval.dtype, ndims=ndims, input_shape=in_shape,
-      strides_in=strides_in, strides_out=strides_out, scale=scale,
-      fft_type=fft_type, fft_lengths=fft_lengths, result_shape=out_shape)]
-
 def _naive_rfft(x, fft_lengths):
   y = fft(x, xla_client.FftType.FFT, fft_lengths)
   n = fft_lengths[-1]
@@ -253,8 +181,3 @@ def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
 mlir.register_lowering(fft_p, _fft_lowering)
 ad.deflinear2(fft_p, _fft_transpose_rule)
 batching.primitive_batchers[fft_p] = _fft_batching_rule
-
-# TODO(phawkins): when jaxlib 0.4.21 is the minimum, use XLA's FFT lowering
-# always on CPU. At that point, we can also delete the DUCC FFT kernel from JAX.
-if xla_extension_version < 211:
-  mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
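The deleted `_fft_lowering_cpu` computed row-major strides for the DUCC custom call by a reverse cumulative product over the shape. A standalone sketch of that stride computation, kept only as a reference for what the removed loops did:

def row_major_strides(shape):
  # Walk dimensions from last to first, accumulating the product of the
  # trailing dimensions, then reverse; mirrors the removed lowering's loops.
  strides = []
  stride = 1
  for d in reversed(shape):
    strides.append(stride)
    stride *= d
  return tuple(reversed(strides))

assert row_major_strides((2, 3, 4)) == (12, 4, 1)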

jax/_src/xla_bridge.py

Lines changed: 9 additions & 24 deletions
@@ -129,10 +129,6 @@ def _get_tpu_library_path() -> str | None:

   libtpu_module = maybe_import_libtpu()
   if libtpu_module is not None:
-    if xla_extension_version < 212:
-      # xla_extension_version < 212 uses tpu_tracer which requires calling
-      # configure_library_path.
-      libtpu_module.configure_library_path()
     return libtpu_module.get_library_path()

   return None
@@ -226,28 +222,17 @@ def register_backend_factory(name: str, factory: BackendFactory, *,


 def make_cpu_client() -> xla_client.Client:
-  if xla_extension_version >= 223:
-    collectives: xla_client._xla.CpuCollectives | None = None
-    if _CPU_ENABLE_GLOO_COLLECTIVES.value:
-      collectives = xla_client._xla.make_gloo_tcp_collectives(  # type: ignore
-        distributed_client=distributed.global_state.client,
-      )
-    return xla_client.make_cpu_client(  # type: ignore
+  collectives: xla_client._xla.CpuCollectives | None = None
+  if _CPU_ENABLE_GLOO_COLLECTIVES.value:
+    collectives = xla_client._xla.make_gloo_tcp_collectives(  # type: ignore
       distributed_client=distributed.global_state.client,
-      node_id=distributed.global_state.process_id,
-      num_nodes=distributed.global_state.num_processes,
-      collectives=collectives,
     )
-  elif xla_extension_version >= 216:
-    # TODO(phawkins): remove type: ignore after updating jaxlib version used for
-    # mypy checks.
-    return xla_client.make_cpu_client(  # type: ignore
-      distributed_client=distributed.global_state.client,
-      node_id=distributed.global_state.process_id,
-      num_nodes=distributed.global_state.num_processes,
-    )
-  else:
-    return xla_client.make_cpu_client()
+  return xla_client.make_cpu_client(  # type: ignore
+    distributed_client=distributed.global_state.client,
+    node_id=distributed.global_state.process_id,
+    num_nodes=distributed.global_state.num_processes,
+    collectives=collectives,
+  )


 register_backend_factory(
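With the version guards gone, `make_cpu_client` above always threads the distributed client, process ids, and optional gloo TCP collectives into `xla_client.make_cpu_client`. A hedged sketch of how a multi-process CPU setup might exercise that path; the flag name is assumed from `_CPU_ENABLE_GLOO_COLLECTIVES`, and the coordinator address is made up:

import jax

# Assumed user-facing name for _CPU_ENABLE_GLOO_COLLECTIVES; verify against
# your jax version before relying on it.
jax.config.update("jax_cpu_enable_gloo_collectives", True)

# Initializing the distributed runtime gives make_cpu_client() a distributed
# client, node id, and node count to forward to xla_client.make_cpu_client.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # hypothetical address
    num_processes=2,
    process_id=0,
)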

jax/version.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):


 __version__ = _get_version_string()
-_minimum_jaxlib_version = "0.4.20"
+_minimum_jaxlib_version = "0.4.23"

 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
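The context lines above show how the minimum version is compared: `_version_as_tuple` turns the string into an integer tuple, and tuples compare lexicographically. A small usage sketch of that helper:

def _version_as_tuple(version_str):
  return tuple(int(i) for i in version_str.split(".") if i.isdigit())

assert _version_as_tuple("0.4.23") == (0, 4, 23)
# Lexicographic tuple comparison makes a minimum-version gate read naturally:
assert _version_as_tuple("0.4.23") >= _version_as_tuple("0.4.20")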

tests/api_test.py

Lines changed: 0 additions & 2 deletions
@@ -58,7 +58,6 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 import jax._src.util as jax_util
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.custom_batching
@@ -4590,7 +4589,6 @@ def test_mesh_creation_error_message(self):
     with self.assertRaisesRegex(ValueError, "ndim of its first argument"):
       jax.sharding.Mesh(jax.devices(), ("x", "y"))

-  @unittest.skipIf(xla_extension_version < 222, 'jaxlib version too old')
   def test_jit_boundmethod_reference_cycle(self):
     class A:
       def __init__(self):

tests/array_interoperability_test.py

Lines changed: 0 additions & 1 deletion
@@ -221,7 +221,6 @@ def testJaxToNumpy(self, shape, dtype):
     x_np = np.from_dlpack(x_jax)
     self.assertAllClose(x_np, x_jax)

-  @unittest.skipIf(xla_extension_version < 221, "Requires newer jaxlib")
   def testNondefaultLayout(self):
     # Generate numpy array with nonstandard layout
     a = np.arange(4).reshape(2, 2)

tests/export_back_compat_test.py

Lines changed: 5 additions & 13 deletions
@@ -60,7 +60,6 @@

 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.lib import version as jaxlib_version

 config.parse_flags_with_absl()

@@ -142,23 +141,16 @@ def test_ducc_fft(self):
     def func(x):
       return lax.fft(x, fft_type="fft", fft_lengths=(4,))

-    # An old lowering, with ducc_fft. We keep it for 6 months.
-    data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
-    if jaxlib_version <= (0, 4, 20):
-      expect_current_custom_calls = ["dynamic_ducc_fft"]
-    else:
-      # We have changed the lowering for fft since we saved this data.
-      # FFT no longer lowers to a custom call.
-      expect_current_custom_calls = []
+    # TODO(b/311175955): Remove this test and the corresponding custom calls.

-    self.run_one_test(func, data,
-                      expect_current_custom_calls=expect_current_custom_calls)
+    # An old lowering, with ducc_fft.
+    data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
+    self.run_one_test(func, data, expect_current_custom_calls=[])

     # A newer lowering, with dynamic_ducc_fft.
     data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
     # FFT no longer lowers to a custom call.
-    self.run_one_test(func, data,
-                      expect_current_custom_calls=expect_current_custom_calls)
+    self.run_one_test(func, data, expect_current_custom_calls=[])

   def cholesky_input(self, shape, dtype):
     a = jtu.rand_default(self.rng())(shape, dtype)
