Skip to content

Commit d790c88

Browse files
yashk2810jax authors
authored andcommitted
Rename layout.AUTO to DeviceLocalLayout.AUTO
PiperOrigin-RevId: 621684185
1 parent 783d5d2 commit d790c88

File tree

3 files changed

+21
-23
lines changed

3 files changed

+21
-23
lines changed

jax/_src/layout.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,17 @@
2121
from jax._src.lib import xla_client as xc
2222

2323

24+
class AutoLayout:
25+
26+
def __repr__(self):
27+
return "AUTO"
28+
29+
2430
class DeviceLocalLayout:
2531
layout: xc.PjRtLayout
2632

33+
AUTO = AutoLayout()
34+
2735
def __init__(self, layout: xc.PjRtLayout):
2836
self._layout = layout
2937
self._layout_str = str(self._layout)
@@ -43,14 +51,6 @@ def _to_xla_layout(self) -> str:
4351
return self._layout_str
4452

4553

46-
class AutoLayout:
47-
48-
def __repr__(self):
49-
return "AUTO"
50-
51-
AUTO = AutoLayout()
52-
53-
5454
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]
5555
ShardingOptions = Union[Sharding, None, AutoSharding]
5656

jax/experimental/layout.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,5 @@
1414

1515
from jax._src.layout import (
1616
DeviceLocalLayout as DeviceLocalLayout,
17-
AUTO as AUTO,
1817
Layout as Layout
1918
)

tests/layout_test.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
2424
from jax._src import config
25-
from jax._src import layout
26-
from jax._src.layout import Layout
25+
from jax._src.layout import Layout, DeviceLocalLayout as DLL
2726
from jax._src import test_util as jtu
2827
from jax._src.util import safe_zip
2928
from jax._src import xla_bridge
@@ -90,7 +89,7 @@ def init(x, y):
9089
sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)
9190

9291
lowered_apply = jax.jit(apply).lower(
93-
sds1, sds2, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO)
92+
sds1, sds2, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO)
9493
compiled_apply = lowered_apply.compile()
9594

9695
arg_layouts, kw_layouts = compiled_apply._input_layouts()
@@ -159,8 +158,8 @@ def f(x):
159158
self.assertArraysEqual(out, np_inp.T)
160159
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
161160

162-
compiled_auto = jax.jit(f).lower(sds, _in_layouts=layout.AUTO,
163-
_out_layouts=layout.AUTO).compile()
161+
compiled_auto = jax.jit(f).lower(sds, _in_layouts=DLL.AUTO,
162+
_out_layouts=DLL.AUTO).compile()
164163
self.assertTupleEqual(
165164
extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
166165
self.assertTupleEqual(
@@ -177,7 +176,7 @@ def f(x):
177176
return x.T
178177

179178
compiled = jax.jit(f).lower(
180-
arr, _in_layouts=None, _out_layouts=layout.AUTO).compile()
179+
arr, _in_layouts=None, _out_layouts=DLL.AUTO).compile()
181180
self.assertTupleEqual(
182181
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
183182
self.assertTupleEqual(
@@ -195,7 +194,7 @@ def test_sharding_and_layouts(self):
195194
s = NamedSharding(mesh, P('x', 'y'))
196195

197196
compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
198-
np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
197+
np_inp, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
199198
out = compiled(np_inp)
200199
self.assertTupleEqual(
201200
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
@@ -210,8 +209,8 @@ def f(x, y, z, a, b, c):
210209

211210
shape = (8, 2)
212211
inps = [np.arange(math.prod(shape)).reshape(shape)] * 6
213-
compiled = jax.jit(f).lower(*inps, _in_layouts=layout.AUTO,
214-
_out_layouts=layout.AUTO).compile()
212+
compiled = jax.jit(f).lower(*inps, _in_layouts=DLL.AUTO,
213+
_out_layouts=DLL.AUTO).compile()
215214
arg_layouts, _ = compiled._input_layouts()
216215
out1, out2 = compiled(*inps)
217216

@@ -244,10 +243,10 @@ def f(x):
244243
with self.assertRaisesRegex(
245244
ValueError,
246245
'Layout passed to jit does not match the layout on the respective arg'):
247-
jax.jit(f).lower(arr, _in_layouts=layout.AUTO)
246+
jax.jit(f).lower(arr, _in_layouts=DLL.AUTO)
248247

249248
compiled = jax.jit(f).lower(
250-
sds, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
249+
sds, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
251250

252251
with self.assertRaisesRegex(
253252
ValueError,
@@ -270,7 +269,7 @@ def test_device_put_concrete_layout(self):
270269
arr = jax.device_put(np_inp, s)
271270

272271
compiled = jax.jit(
273-
lambda x: x * 2).lower(arr, _out_layouts=layout.AUTO).compile()
272+
lambda x: x * 2).lower(arr, _out_layouts=DLL.AUTO).compile()
274273
col = compiled._output_layouts()
275274

276275
out = jax.device_put(np_inp, col)
@@ -283,12 +282,12 @@ def test_device_put_concrete_layout(self):
283282
def test_device_put_non_concrete_layout_error(self):
284283
np_inp = np.arange(16).reshape(8, 2)
285284

286-
l1 = Layout(layout.AUTO, SingleDeviceSharding(jax.devices()[0]))
285+
l1 = Layout(DLL.AUTO, SingleDeviceSharding(jax.devices()[0]))
287286
with self.assertRaisesRegex(
288287
ValueError, 'sharding and device_local_layout.*should be concrete'):
289288
jax.device_put(np_inp, l1)
290289

291-
l2 = Layout(layout.AUTO, None)
290+
l2 = Layout(DLL.AUTO, None)
292291
with self.assertRaisesRegex(
293292
ValueError, 'sharding and device_local_layout.*should be concrete'):
294293
jax.device_put(np_inp, l2)

0 commit comments

Comments
 (0)