Skip to content

Commit a8653c9

Browse files
authored
disable tests using mark_sharding + assume_pure on GPU (#9013)
1 parent 1b7422d commit a8653c9

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

test/scan/test_scan_spmd.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import sys
23
import re
34
import unittest
@@ -230,11 +231,19 @@ def check_dots_in_model(self, model, x, expect_pattern):
230231
def count_regex(self, hlo_text, regex_str):
231232
return len(re.findall(regex_str, hlo_text))
232233

234+
@unittest.skipIf(
235+
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
236+
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
237+
)
233238
def test_assume_pure_works_with_mark_sharding(self):
234239
x = torch.randn((3, 4, 5, 128), device='xla')
235240
assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None))
236241
# assert not throwing
237242

243+
@unittest.skipIf(
244+
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
245+
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
246+
)
238247
def test_convert_to_jax_mesh(self):
239248
jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh()
240249
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)

0 commit comments

Comments
 (0)