@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(xr.device_type() == 'CPU',
-                   "sharding will be the same for both tensors on single device"
-                  )
+  @unittest.skipIf(
+      xr.device_type() == 'CPU',
+      "sharding will be the same for both tensors on single device")
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1383,8 +1383,9 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1405,8 +1406,9 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1427,8 +1429,9 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
-                       "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(
+      xr.global_runtime_device_count() > 1,
+      "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1660,9 +1663,9 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
-  @unittest.skipIf(xr.device_type() == 'CPU',
-                   "sharding will be the same for both tensors on single device"
-                  )
+  @unittest.skipIf(
+      xr.device_type() == 'CPU',
+      "sharding will be the same for both tensors on single device")
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)