@@ -3881,6 +3881,26 @@ def f():
3881
3881
lowered_text = make_keys .lower (seeds ).as_text ()
3882
3882
self .assertIn ('unspecified_dims=[0,1]' , lowered_text )
3883
3883
3884
+ def test_jit_partially_specified_shardings (self ):
3885
+ mesh = jtu .create_global_mesh ((2 , 2 ), ('x' , 'y' ))
3886
+ np_inp = np .arange (16 ).reshape (8 , 2 )
3887
+ s = NamedSharding (mesh , P ('x' , 'y' ))
3888
+ s2 = NamedSharding (mesh , P ('x' ))
3889
+ arr = jax .device_put (np_inp , s )
3890
+ arr2 = jax .device_put (np_inp , s2 )
3891
+
3892
+ @partial (jax .jit , in_shardings = (s , None , s2 , UNSPECIFIED , UNSPECIFIED ),
3893
+ out_shardings = (s2 , None , None , s , None ))
3894
+ def f (x , y , z , a , b ):
3895
+ return x * 2 , y @ y .T , z ** 2 , a * 3 , b .T
3896
+
3897
+ out1 , out2 , out3 , out4 , out5 = f (arr , np_inp , arr2 , np_inp , arr )
3898
+ self .assertArraysEqual (out1 , np_inp * 2 )
3899
+ self .assertArraysEqual (out2 , np_inp @ np_inp .T )
3900
+ self .assertArraysEqual (out3 , np_inp ** 2 )
3901
+ self .assertArraysEqual (out4 , np_inp * 3 )
3902
+ self .assertArraysEqual (out5 , np_inp .T )
3903
+
3884
3904
3885
3905
class TempSharding (Sharding ):
3886
3906
@@ -4314,24 +4334,6 @@ def test_array_mapping_to_axis_resources(self, inp, expected_out):
4314
4334
sharding_impls .array_mapping_to_axis_resources (inp ), expected_out
4315
4335
)
4316
4336
4317
- @parameterized .named_parameters (
4318
- ("all_unspecified" , (UNSPECIFIED , UNSPECIFIED ), AssertionError ),
4319
- ("only_unspecified" , UNSPECIFIED ),
4320
- ("all_specified" , (P ('x' ), P ('y' ))),
4321
- ("only_specified" , P ('x' )),
4322
- ("mix_1" , (P ('x' ), UNSPECIFIED ), ValueError ),
4323
- ("mix_2" , (P ('x' ), UNSPECIFIED , P ('y' )), ValueError ),
4324
- ("mix_3" , (UNSPECIFIED , P ('x' ), P ('y' )), ValueError ),
4325
- ("mix_4" , (UNSPECIFIED , P ('x' ), UNSPECIFIED ), ValueError ),
4326
- )
4327
- def test_all_or_non_unspecified (self , axis_resources , error = None ):
4328
- entries , _ = jax .tree .flatten (axis_resources , is_leaf = lambda x : x is None )
4329
- if error is not None :
4330
- with self .assertRaises (error ):
4331
- sharding_impls .check_all_or_none_unspecified (entries , 'test axis resources' )
4332
- else :
4333
- sharding_impls .check_all_or_none_unspecified (entries , 'test axis resources' )
4334
-
4335
4337
def test_op_sharding_equality_and_hash_equality (self ):
4336
4338
op1 = xc .OpSharding ()
4337
4339
op1 .type = xc .OpSharding .Type .OTHER
0 commit comments