@@ -4347,7 +4347,7 @@ def f(x, y):
4347
4347
input_shardings , _ = f .lower (inp , inp ).compile ().input_shardings
4348
4348
self .assertLen (input_shardings , 2 )
4349
4349
4350
- def test_aot_out_info (self ):
4350
+ def test_lowered_out_info (self ):
4351
4351
inp = np .arange (8 , dtype = np .int32 )
4352
4352
out_info = jax .jit (lambda x : x ).lower ((inp , inp )).out_info
4353
4353
self .assertEqual (out_info [0 ].shape , (8 ,))
@@ -4357,6 +4357,26 @@ def test_aot_out_info(self):
4357
4357
self .assertEqual (out_info [0 ].sharding , None )
4358
4358
self .assertEqual (out_info [1 ].sharding , None )
4359
4359
4360
+ def test_lowered_out_info_mesh (self ):
4361
+ mesh = jtu .create_mesh ((2 ,), 'x' )
4362
+ arr = jax .device_put (np .arange (8 , dtype = np .int32 ),
4363
+ NamedSharding (mesh , P ('x' )))
4364
+ lowered = jax .jit (lambda x : x * 2 ).lower (arr )
4365
+ out_info = lowered .out_info
4366
+ self .assertEqual (out_info .shape , (8 ,))
4367
+ self .assertEqual (out_info .dtype , np .int32 )
4368
+ self .assertEqual (out_info .sharding , None )
4369
+
4370
+ def test_compiled_out_info (self ):
4371
+ mesh = jtu .create_mesh ((2 ,), 'x' )
4372
+ arr = jax .device_put (np .arange (8 , dtype = np .int32 ),
4373
+ NamedSharding (mesh , P ('x' )))
4374
+ compiled = jax .jit (lambda x : x * 2 ).lower (arr ).compile ()
4375
+ out_info = compiled .out_info
4376
+ self .assertEqual (out_info .shape , (8 ,))
4377
+ self .assertEqual (out_info .dtype , np .int32 )
4378
+ self .assertEqual (out_info .sharding , NamedSharding (mesh , P ('x' )))
4379
+
4360
4380
def test_jit_trace (self ):
4361
4381
def f (x ):
4362
4382
return x * 2
0 commit comments