
Commit 8e682cd

update yapf
1 parent 97dc6c3 commit 8e682cd

13 files changed: +50 -47 lines changed

CONTRIBUTING.md

+1 -1

@@ -169,7 +169,7 @@ find -name '*.cpp' -o -name '*.h' -o -name '*.cc' | xargs clang-format-11 -i -st
 If your PR touches the Python source files, please run the following command before submitting a PR.
 
 ```Shell
-# How to install: pip install yapf==0.30.0
+# How to install: pip install yapf==0.40.2
 yapf --recursive -i *.py test/ scripts/ torch_xla/ benchmarks/
 ```
 

benchmarks/benchmark_model.py

+2 -3

@@ -227,9 +227,8 @@ def is_compatible(self, dummy_benchmark_model: BenchmarkModel,
 def get_benchmark_indices(self, length: int):
 start = self._args.partition_id * (length // self._args.total_partitions)
 end = ((self._args.partition_id + 1) *
-(length // self._args.total_partitions)
-if self._args.partition_id < self._args.total_partitions - 1 else
-length)
+(length // self._args.total_partitions) if self._args.partition_id
+< self._args.total_partitions - 1 else length)
 return start, end
 
 def skip_model(self, model_name: str):

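The hunk above is whitespace-only; the partition arithmetic is unchanged. As a reading aid, here is a minimal standalone sketch of what `get_benchmark_indices` computes (rewritten as a free function with explicit arguments in place of `self._args`; the last partition absorbs the remainder):

```python
def get_benchmark_indices(partition_id: int, total_partitions: int,
                          length: int):
  # Every partition gets length // total_partitions benchmarks; the last
  # partition also takes the remainder so all indices are covered.
  start = partition_id * (length // total_partitions)
  end = ((partition_id + 1) * (length // total_partitions)
         if partition_id < total_partitions - 1 else length)
  return start, end


# Example: 10 benchmarks over 3 partitions -> (0, 3), (3, 6), (6, 10).
print([get_benchmark_indices(i, 3, 10) for i in range(3)])
```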
infra/ansible/config/pip.yaml

+1 -1

@@ -28,7 +28,7 @@ pip:
 - tqdm
 - typing_extensions
 - sympy
-- yapf==0.30.0
+- yapf==0.40.2
 
 build_amd64:
 - mkl

test/pytorch_test_base.py

+2 -2

@@ -619,8 +619,8 @@ def skipped_test(self, *args, reason=reason, **kwargs):
 setattr(cls, dtype_test_name, disallowed_test)
 if not skipped:
 xla_dtypes.append(
-dtype_combination
-if len(dtype_combination) > 1 else dtype_combination[0])
+dtype_combination if len(dtype_combination) >
+1 else dtype_combination[0])
 if len(xla_dtypes) != 0:
 test.dtypes[cls.device_type] = xla_dtypes
 super().instantiate_test(name, test, generic_cls=generic_cls)

test/spmd/test_xla_sharding.py

+15 -12

@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):
 
 # avoid calling xr.addressable_device_count here otherwise it will init the test
 # in non-spmd mode.
-@unittest.skipIf(xr.device_type() == 'CPU',
-"sharding will be the same for both tensors on single device"
-)
+@unittest.skipIf(
+xr.device_type() == 'CPU',
+"sharding will be the same for both tensors on single device")
 def test_shard_hashing(self):
 xt1 = torch.ones(2, 2).to(xm.xla_device())
 xt2 = torch.ones(2, 2).to(xm.xla_device())

@@ -1383,8 +1383,9 @@ def test_get_1d_mesh(self):
 self.assertEqual(mesh_without_name.mesh_shape,
 (xr.global_runtime_device_count(),))
 
-@unittest.skipUnless(xr.global_runtime_device_count() > 1,
-"Multiple devices required for dataloader sharding test")
+@unittest.skipUnless(
+xr.global_runtime_device_count() > 1,
+"Multiple devices required for dataloader sharding test")
 def test_data_loader_with_sharding(self):
 device = torch_xla.device()
 mesh = xs.get_1d_mesh("data")

@@ -1405,8 +1406,9 @@ def test_data_loader_with_sharding(self):
 f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
 )
 
-@unittest.skipUnless(xr.global_runtime_device_count() > 1,
-"Multiple devices required for dataloader sharding test")
+@unittest.skipUnless(
+xr.global_runtime_device_count() > 1,
+"Multiple devices required for dataloader sharding test")
 def test_data_loader_with_non_batch_size(self):
 device = torch_xla.device()
 mesh = xs.get_1d_mesh("data")

@@ -1427,8 +1429,9 @@ def test_data_loader_with_non_batch_size(self):
 f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
 )
 
-@unittest.skipUnless(xr.global_runtime_device_count() > 1,
-"Multiple devices required for dataloader sharding test")
+@unittest.skipUnless(
+xr.global_runtime_device_count() > 1,
+"Multiple devices required for dataloader sharding test")
 def test_data_loader_with_non_batch_size_and_mini_batch(self):
 device = torch_xla.device()
 mesh = xs.get_1d_mesh("data")

@@ -1660,9 +1663,9 @@ def test_get_logical_mesh(self):
 self.assertEqual(logical_mesh.shape, mesh_shape)
 np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
-@unittest.skipIf(xr.device_type() == 'CPU',
-"sharding will be the same for both tensors on single device"
-)
+@unittest.skipIf(
+xr.device_type() == 'CPU',
+"sharding will be the same for both tensors on single device")
 def test_shard_as(self):
 mesh = self._get_mesh((self.n_devices,))
 partition_spec = (0,)

test/test_operations.py

+3 -5

@@ -2959,11 +2959,9 @@ def test_dlpack_roundtrip_tensor(self, dtype):
 
 @onlyIfTorchSupportsCUDA
 @onlyIfPJRTDeviceIsCUDA
-@parameterized.parameters(*all_types_and_complex_and(torch.half,
-torch.bfloat16,
-torch.bool, torch.uint16,
-torch.uint32,
-torch.uint64))
+@parameterized.parameters(
+*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool,
+torch.uint16, torch.uint32, torch.uint64))
 def test_dlpack_roundtrip_scalar(self, dtype):
 xla_device = xm.xla_device()
 xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)

test/test_pallas.py

+4 -4

@@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase):
 # therefore we use != instead of ==.
 def _make_attention_mask_from_segment_ids(self, q_segment_ids,
 kv_segment_ids):
-return q_segment_ids.view(q_segment_ids.shape[0], 1,
-q_segment_ids.shape[1], 1) != kv_segment_ids.view(
-kv_segment_ids.shape[0], 1, 1,
-kv_segment_ids.shape[1])
+return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
+1) != kv_segment_ids.view(kv_segment_ids.shape[0],
+1, 1,
+kv_segment_ids.shape[1])
 
 def _attention(self, q, k, v, *, attn_mask=None, ab=None):
 attn_weight = q @ k.transpose(-2, -1)

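`_make_attention_mask_from_segment_ids` is only re-wrapped by yapf here. What it does: reshape a `[batch, q_len]` and a `[batch, kv_len]` tensor of segment ids so they broadcast into a `[batch, 1, q_len, kv_len]` boolean mask that is True wherever query and key fall in different segments. A standalone sketch with toy inputs (same expression as the hunk, invented shapes):

```python
import torch


def make_attention_mask_from_segment_ids(q_segment_ids, kv_segment_ids):
  # [B, Q] -> [B, 1, Q, 1] and [B, K] -> [B, 1, 1, K]; broadcasting the
  # inequality produces [B, 1, Q, K], True where the pair crosses segments.
  return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
                            1) != kv_segment_ids.view(kv_segment_ids.shape[0],
                                                      1, 1,
                                                      kv_segment_ids.shape[1])


q_ids = torch.tensor([[0, 0, 1, 1]])  # one batch, four query positions
kv_ids = torch.tensor([[0, 1, 1]])  # three key/value positions
mask = make_attention_mask_from_segment_ids(q_ids, kv_ids)
print(mask.shape)  # torch.Size([1, 1, 4, 3])
```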
test/test_pallas_spmd.py

+4 -4

@@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase):
 # therefore we use != instead of ==.
 def _make_attention_mask_from_segment_ids(self, q_segment_ids,
 kv_segment_ids):
-return q_segment_ids.view(q_segment_ids.shape[0], 1,
-q_segment_ids.shape[1], 1) != kv_segment_ids.view(
-kv_segment_ids.shape[0], 1, 1,
-kv_segment_ids.shape[1])
+return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
+1) != kv_segment_ids.view(kv_segment_ids.shape[0],
+1, 1,
+kv_segment_ids.shape[1])
 
 def _attention(self, q, k, v, *, attn_mask=None, ab=None):
 attn_weight = q @ k.transpose(-2, -1)

test/test_splash_attention.py

+4 -4

@@ -62,10 +62,10 @@ def setUp(self):
 
 def _make_attention_mask_from_segment_ids(self, q_segment_ids,
 kv_segment_ids):
-return q_segment_ids.view(q_segment_ids.shape[0], 1,
-q_segment_ids.shape[1], 1) != kv_segment_ids.view(
-kv_segment_ids.shape[0], 1, 1,
-kv_segment_ids.shape[1])
+return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
+1) != kv_segment_ids.view(kv_segment_ids.shape[0],
+1, 1,
+kv_segment_ids.shape[1])
 
 def maybe_repeat_kv(self, hidden_state):
 if hidden_state.size(1) == self.NUM_Q_HEADS:

torch_xla/distributed/xla_multiprocessing.py

+4 -1

@@ -174,7 +174,10 @@ def _v6e_create_replica_groups() -> List | None:
 return None
 
 
-device_kind_handler_dict: dict[str, Callable[..., List | None],] = {
+device_kind_handler_dict: dict[
+str,
+Callable[..., List | None],
+] = {
 _TPU_V5P: _v5p_create_replica_groups,
 _TPU_V6E: _v6e_create_replica_groups
 }

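The reformatting above splits the annotation of `device_kind_handler_dict`, a dispatch table from a TPU device-kind string to a handler that returns replica groups (or None). The call site is not shown in this diff; below is a hedged sketch of how such a table is typically consulted, with hypothetical key strings and no-op handlers standing in for the real `_TPU_V5P`/`_TPU_V6E` constants and functions:

```python
from typing import Callable, List


def _v5p_create_replica_groups() -> List | None:  # placeholder handler
  return None


def _v6e_create_replica_groups() -> List | None:  # placeholder handler
  return None


device_kind_handler_dict: dict[
    str,
    Callable[..., List | None],
] = {
    'TPU v5p': _v5p_create_replica_groups,  # hypothetical key strings
    'TPU v6e': _v6e_create_replica_groups,
}


def create_replica_groups(device_kind: str) -> List | None:
  # Unknown device kinds fall through to None, i.e. default replica groups.
  handler = device_kind_handler_dict.get(device_kind)
  return handler() if handler is not None else None
```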
torch_xla/experimental/gradient_accumulation.py

+2 -2

@@ -288,8 +288,8 @@ def add_to_mapping(val: torch.Tensor,
 iterable_tensors, fake_iterable_tensors, carried_tensors,
 fake_carried_tensors, params, grads)
 
-def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op,
-*while_params: xb.Op):
+def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, *while_params:
+xb.Op):
 
 def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op:
 indices = [idx] + [idx.zeros_like() for _ in range(xs.shape().rank - 1)]

torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py

+6 -6

@@ -243,8 +243,8 @@ def make_sequence_metadata(
 #
 # Remove tile visits that belong to a sequence not in our shard.
 iota = jnp.arange(num_sequences, dtype=jnp.int32)
-active_sequence_mask = jnp.logical_and(iota <= end_sequence,
-iota >= start_sequence)
+active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota
+>= start_sequence)
 sequence_tiles = jnp.where(active_sequence_mask,
 sequence_tiles[:num_sequences], 0)
 num_tiles = sequence_tiles.sum()

@@ -375,8 +375,8 @@ def _flash_attention(
 logical_q_blk_idx - 1, 0)
 is_first_processed_logical_q_blk = logical_q_blk_idx == 0
 physical_q_blk_changed = (
-physical_q_tile_ids[logical_q_blk_idx] !=
-physical_q_tile_ids[prev_logical_q_blk_idx])
+physical_q_tile_ids[logical_q_blk_idx]
+!= physical_q_tile_ids[prev_logical_q_blk_idx])
 first_time_seeing_physical_q_blk = jnp.logical_or(
 is_first_processed_logical_q_blk, physical_q_blk_changed)
 is_first_kv_blk = (kv_blk_idx == 0)

@@ -509,8 +509,8 @@ def init_scratch_ref(): # pylint: disable=unused-variable
 logical_q_blk_idx + 1)
 is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1)
 physical_q_blk_will_change = (
-physical_q_tile_ids[logical_q_blk_idx] !=
-physical_q_tile_ids[next_logical_q_blk_idx])
+physical_q_tile_ids[logical_q_blk_idx]
+!= physical_q_tile_ids[next_logical_q_blk_idx])
 last_time_seeing_cur_physical_q_blk = jnp.logical_or(
 is_last_logical_q_blk, physical_q_blk_will_change)
 should_store_to_output = jnp.logical_and(is_last_kv_blk_idx,

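The first hunk in this file only re-breaks a `jnp.logical_and` call. The underlying step zeroes out tile counts for sequences that fall outside the shard's `[start_sequence, end_sequence]` range before summing the remaining work. A rough numpy sketch of that masking step (invented sizes and values, numpy in place of jax.numpy):

```python
import numpy as np

# Tiles of work per sequence, for 6 sequences (made-up numbers).
sequence_tiles = np.array([2, 1, 3, 2, 4, 1], dtype=np.int32)
start_sequence, end_sequence = 2, 4  # this shard owns sequences 2..4

iota = np.arange(len(sequence_tiles), dtype=np.int32)
active_sequence_mask = np.logical_and(iota <= end_sequence,
                                      iota >= start_sequence)
sequence_tiles = np.where(active_sequence_mask, sequence_tiles, 0)

print(sequence_tiles)  # [0 0 3 2 4 0]
print(sequence_tiles.sum())  # 9 tiles to process in this shard
```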
torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

+2 -2

@@ -421,8 +421,8 @@ def init_scratch_ref():
 )
 causal_mask = row_ids < col_ids
 if sliding_window is not None:
-causal_mask = jnp.logical_or(causal_mask,
-row_ids - sliding_window >= col_ids)
+causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window
+>= col_ids)
 if soft_cap is not None:
 qk = soft_cap * jnp.tanh(qk / soft_cap)
 qk += jnp.where(causal_mask, mask_value, 0.0)

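Only the wrapping of the `jnp.logical_or` call changes in this hunk. The mask it builds flags (query, key) pairs to exclude from attention: keys in the future of the query (`row_ids < col_ids`) and, when `sliding_window` is set, keys at least `sliding_window` positions in the past. A small numpy sketch of the combined mask (toy sizes, numpy in place of jax.numpy):

```python
import numpy as np

seq_len, sliding_window = 6, 3
row_ids = np.arange(seq_len)[:, None]  # query positions
col_ids = np.arange(seq_len)[None, :]  # key positions

# True means "mask this (query, key) pair out of attention".
causal_mask = row_ids < col_ids  # no attending to future keys
causal_mask = np.logical_or(causal_mask, row_ids - sliding_window
                            >= col_ids)  # and none too far in the past

# With sliding_window=3 each query sees itself and at most the 2 previous keys.
print((~causal_mask).astype(int))
```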