22
22
import jax .numpy as jnp
23
23
from jax .sharding import NamedSharding , PartitionSpec as P , SingleDeviceSharding
24
24
from jax ._src import config
25
- from jax ._src import layout
26
- from jax ._src .layout import Layout
25
+ from jax ._src .layout import Layout , DeviceLocalLayout as DLL
27
26
from jax ._src import test_util as jtu
28
27
from jax ._src .util import safe_zip
29
28
from jax ._src import xla_bridge
@@ -90,7 +89,7 @@ def init(x, y):
90
89
sds2 = jax .ShapeDtypeStruct (np_inp2 .shape , np_inp2 .dtype , sharding = s2 )
91
90
92
91
lowered_apply = jax .jit (apply ).lower (
93
- sds1 , sds2 , _in_layouts = layout .AUTO , _out_layouts = layout .AUTO )
92
+ sds1 , sds2 , _in_layouts = DLL .AUTO , _out_layouts = DLL .AUTO )
94
93
compiled_apply = lowered_apply .compile ()
95
94
96
95
arg_layouts , kw_layouts = compiled_apply ._input_layouts ()
@@ -159,8 +158,8 @@ def f(x):
159
158
self .assertArraysEqual (out , np_inp .T )
160
159
self .assertEqual (out .sharding , NamedSharding (mesh , P (None , 'y' , 'x' )))
161
160
162
- compiled_auto = jax .jit (f ).lower (sds , _in_layouts = layout .AUTO ,
163
- _out_layouts = layout .AUTO ).compile ()
161
+ compiled_auto = jax .jit (f ).lower (sds , _in_layouts = DLL .AUTO ,
162
+ _out_layouts = DLL .AUTO ).compile ()
164
163
self .assertTupleEqual (
165
164
extract_minor_to_major (compiled_auto ._input_layouts ()[0 ][0 ]), (2 , 1 , 0 ))
166
165
self .assertTupleEqual (
@@ -177,7 +176,7 @@ def f(x):
177
176
return x .T
178
177
179
178
compiled = jax .jit (f ).lower (
180
- arr , _in_layouts = None , _out_layouts = layout .AUTO ).compile ()
179
+ arr , _in_layouts = None , _out_layouts = DLL .AUTO ).compile ()
181
180
self .assertTupleEqual (
182
181
extract_minor_to_major (compiled ._input_layouts ()[0 ][0 ]), (1 , 0 ))
183
182
self .assertTupleEqual (
@@ -195,7 +194,7 @@ def test_sharding_and_layouts(self):
195
194
s = NamedSharding (mesh , P ('x' , 'y' ))
196
195
197
196
compiled = jax .jit (lambda x : x .T , in_shardings = s , out_shardings = s ).lower (
198
- np_inp , _in_layouts = layout .AUTO , _out_layouts = layout .AUTO ).compile ()
197
+ np_inp , _in_layouts = DLL .AUTO , _out_layouts = DLL .AUTO ).compile ()
199
198
out = compiled (np_inp )
200
199
self .assertTupleEqual (
201
200
extract_minor_to_major (compiled ._input_layouts ()[0 ][0 ]), (1 , 0 ))
@@ -210,8 +209,8 @@ def f(x, y, z, a, b, c):
210
209
211
210
shape = (8 , 2 )
212
211
inps = [np .arange (math .prod (shape )).reshape (shape )] * 6
213
- compiled = jax .jit (f ).lower (* inps , _in_layouts = layout .AUTO ,
214
- _out_layouts = layout .AUTO ).compile ()
212
+ compiled = jax .jit (f ).lower (* inps , _in_layouts = DLL .AUTO ,
213
+ _out_layouts = DLL .AUTO ).compile ()
215
214
arg_layouts , _ = compiled ._input_layouts ()
216
215
out1 , out2 = compiled (* inps )
217
216
@@ -244,10 +243,10 @@ def f(x):
244
243
with self .assertRaisesRegex (
245
244
ValueError ,
246
245
'Layout passed to jit does not match the layout on the respective arg' ):
247
- jax .jit (f ).lower (arr , _in_layouts = layout .AUTO )
246
+ jax .jit (f ).lower (arr , _in_layouts = DLL .AUTO )
248
247
249
248
compiled = jax .jit (f ).lower (
250
- sds , _in_layouts = layout .AUTO , _out_layouts = layout .AUTO ).compile ()
249
+ sds , _in_layouts = DLL .AUTO , _out_layouts = DLL .AUTO ).compile ()
251
250
252
251
with self .assertRaisesRegex (
253
252
ValueError ,
@@ -270,7 +269,7 @@ def test_device_put_concrete_layout(self):
270
269
arr = jax .device_put (np_inp , s )
271
270
272
271
compiled = jax .jit (
273
- lambda x : x * 2 ).lower (arr , _out_layouts = layout .AUTO ).compile ()
272
+ lambda x : x * 2 ).lower (arr , _out_layouts = DLL .AUTO ).compile ()
274
273
col = compiled ._output_layouts ()
275
274
276
275
out = jax .device_put (np_inp , col )
@@ -283,12 +282,12 @@ def test_device_put_concrete_layout(self):
283
282
def test_device_put_non_concrete_layout_error (self ):
284
283
np_inp = np .arange (16 ).reshape (8 , 2 )
285
284
286
- l1 = Layout (layout .AUTO , SingleDeviceSharding (jax .devices ()[0 ]))
285
+ l1 = Layout (DLL .AUTO , SingleDeviceSharding (jax .devices ()[0 ]))
287
286
with self .assertRaisesRegex (
288
287
ValueError , 'sharding and device_local_layout.*should be concrete' ):
289
288
jax .device_put (np_inp , l1 )
290
289
291
- l2 = Layout (layout .AUTO , None )
290
+ l2 = Layout (DLL .AUTO , None )
292
291
with self .assertRaisesRegex (
293
292
ValueError , 'sharding and device_local_layout.*should be concrete' ):
294
293
jax .device_put (np_inp , l2 )
0 commit comments