Commit 240c2d0

Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO

1 parent 8ff5d33
4 files changed: +79 −2 lines changed

jax/_src/test_util.py (52 additions, 0 deletions)

@@ -468,6 +468,58 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
     return "v6 lite" in device_kind
   return expected_version in device_kind
 
+def pattern_search(patterns: str | Sequence[str], string: str):
+  if not isinstance(patterns, tuple):
+    patterns = (patterns,)  # type: ignore
+
+  for pattern in patterns:
+    if pattern in string:
+      return pattern
+  return None
+
+def device_kind_matches(device_patterns: str | Sequence[str]):
+  device_kind = xla_bridge.devices()[0].device_kind
+  matching_pattern = pattern_search(device_patterns, device_kind)
+  return matching_pattern is not None
+
+def skip_if_errors(
+    *,
+    error_patterns: str | Sequence[str],
+    device_patterns: str | Sequence[str],
+    reason: str | Callable[[str, str], str],
+):
+  """Skip if both error message and device kind match a corresponding pattern."""
+  def skip(test_method):
+    @functools.wraps(test_method)
+    def test_method_wrapper(self, *args, **kwargs):
+      device_kind = xla_bridge.devices()[0].device_kind
+      try:
+        return test_method(self, *args, **kwargs)
+      except Exception as e:
+        matching_error_pattern = pattern_search(error_patterns, str(e))
+        matching_device_pattern = pattern_search(device_patterns, device_kind)
+        if matching_error_pattern and matching_device_pattern:
+          if not isinstance(reason, str):
+            reason_str = reason(matching_error_pattern, matching_device_pattern)
+          else:
+            reason_str = reason
+          self.skipTest(reason_str)
+        raise
+    return test_method_wrapper
+  return skip
+
+skip_if_mosaic_gpu_exceeds_shared_memory = functools.partial(
+    skip_if_errors,
+    error_patterns="kernel exceeds available shared memory",
+    reason=lambda err, dev: f"Mosaic GPU kernel exceeds shared memory on {dev}",
+)
+
+skip_if_triton_exceeds_shared_memory = functools.partial(
+    skip_if_errors,
+    error_patterns="Shared memory size limit exceeded",
+    reason=lambda err, dev: f"Triton kernel exceeds shared memory on {dev}",
+)
+
 def is_cuda_compute_capability_at_least(capability: str) -> bool:
   if not is_device_cuda():
     return False
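
The two `functools.partial` aliases above pre-bind a backend-specific error pattern and reason, so a test only has to name the problematic device. A minimal usage sketch follows (not part of this commit; the test class and test method are hypothetical), assuming the test module imports `jax._src.test_util as jtu` as the JAX test suites do:

import jax._src.test_util as jtu


class ExampleKernelTest(jtu.JaxTestCase):  # hypothetical test class

  # The decorator converts a failure into a skip only when BOTH the raised
  # error message and the current device kind match their patterns; on any
  # other device, or for any other error, the exception is re-raised.
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(
      device_patterns="RTX PRO 6000 Blackwell"
  )
  def test_big_kernel(self):  # hypothetical test body
    ...  # launch a Mosaic GPU kernel that may not fit in shared memory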

tests/mosaic/gpu_test.py (18 additions, 2 deletions)

@@ -664,6 +664,7 @@ def kernel(ctx, inp, out, smem):
           fa.WGMMA_LAYOUT_UPCAST_4X,
       ),
   )
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_optimized_conversion(self, jax_dtype_from_to, layout):
     jax_dtype_from, jax_dtype_to = jax_dtype_from_to
     mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)

@@ -2752,6 +2753,7 @@ def kernel(ctx, src, dst, _):
     np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
 
   @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_strided_load_store(self, in_shape):
     def kernel(ctx, *args):
       gmem_input, gmem_output, (smem_input, smem_output) = args

@@ -3118,7 +3120,13 @@ def get_reg(addr):
         return addr[:pos]
       return addr
     used_regs = {get_reg(addr) for addr in addrs}
-    self.assertLessEqual(len(used_regs), expected_regs)
+    try:
+      self.assertLessEqual(len(used_regs), expected_regs)
+    except:
+      problematic_device = "RTX PRO 6000 Blackwell"
+      if jtu.device_kind_matches(problematic_device):
+        self.skipTest(f"{problematic_device} uses more registers for an unknown reason")
+      raise
 
   def test_copy_for_upcast(self):
     dtype = jnp.int8

@@ -3154,6 +3162,7 @@ def kernel(ctx, in_, out, smems):
           (fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT),
       ],
   )
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_transpose_tiled(self, dtype, swizzle, layouts):
     mlir_dtype = utils.dtype_to_ir_type(dtype)
     bw = bytewidth(mlir_dtype)

@@ -3198,6 +3207,7 @@ def kernel(ctx, in_, out, smems):
       (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
   )
   @jtu.thread_unsafe_test()  # Modifies ``os.environ``.
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_upcast_to_wgmma(
       self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
   ):

@@ -3245,7 +3255,13 @@ def tile(x, tiling):
     yt_kernel = f(xt)
     jax.block_until_ready(yt_kernel)
     np.testing.assert_array_equal(yt_kernel, yt)
-    self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
+    try:
+      self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
+    except:
+      problematic_device = "RTX PRO 6000 Blackwell"
+      if jtu.device_kind_matches(problematic_device):
+        self.skipTest(f"{problematic_device} requires more SHFL.BFLY for an unknown reason")
+      raise
 
 
 @dataclasses.dataclass(frozen=True)
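
The two count-based assertions above (register usage, SHFL.BFLY instruction counts) bypass the new decorator: an assertion failure carries no stable error string to match, so the commit instead checks the device kind inline with `jtu.device_kind_matches`. If the pattern recurred more widely, it could be factored into a context manager along these lines; this is a hypothetical sketch, not part of the commit, and `skip_on_device_if_fails` is an invented name:

import contextlib

import jax._src.test_util as jtu


@contextlib.contextmanager
def skip_on_device_if_fails(testcase, device_pattern: str, reason: str):
  """Turns a failure into a skip, but only on devices matching the pattern."""
  try:
    yield
  except Exception:
    if jtu.device_kind_matches(device_pattern):
      testcase.skipTest(reason)  # raises unittest.SkipTest
    raise  # on any other device, let the original failure propagate

# Usage inside a test method:
#   with skip_on_device_if_fails(
#       self, "RTX PRO 6000 Blackwell",
#       "RTX PRO 6000 Blackwell uses more registers for an unknown reason"):
#     self.assertLessEqual(len(used_regs), expected_regs)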

tests/pallas/mosaic_gpu_test.py (6 additions, 0 deletions)

@@ -226,6 +226,7 @@ def kernel(x_ref, y_ref, o_ref):
     np.testing.assert_array_equal(kernel(x, y), x + y[0])
 
   @parameterized.product(shape=[(128,), (128, 128)])
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_reduce_sum(self, shape):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32)

@@ -719,6 +720,7 @@ def kernel(x_ref, o_ref, barrier_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_scoped_copy_with_transforms(self):
     self.skip_if_wg_semantics()
 

@@ -742,6 +744,7 @@ def body(tmp_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x * 2)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_scoped_copy_with_user_transforms(self):
     def kernel(x_ref, o_ref, barrier_ref):
       def body(tmp_ref):

@@ -762,6 +765,7 @@ def body(tmp_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x * 2)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_copy_with_transforms_and_indexing(self):
     self.skip_if_wg_semantics()
 

@@ -811,6 +815,7 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128)
     np.testing.assert_array_equal(kernel(x), x)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_indexing_before_transpose(self):
     self.skip_if_wg_semantics()
 

@@ -2021,6 +2026,7 @@ def kernel(x_ref, y_ref, smem_ref, smem_out_ref, barrier_ref):
     np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(x[:, None], (128, 128)))
 
   @parameterized.named_parameters((l.name.lower(), l) for l in plgpu.Layout)
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_copy_layout(self, layout):
     self.skip_if_wg_semantics()
     if layout in {

tests/pallas/ops_test.py (3 additions, 0 deletions)

@@ -560,6 +560,7 @@ def kernel(*refs):
       for name, func, strategy in UNARY_FUNCTIONS
   )
   @hp.given(hps.data())
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
     if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]:
       self.skip_if_mosaic_gpu()

@@ -1897,6 +1898,8 @@ def f(x_ref, o_ref):
       trans_x=[False, True],
       trans_y=[False, True],
   )
+  @jtu.skip_if_triton_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
     self.skip_if_mosaic_gpu()
 
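
test_dot stacks both decorators because this suite exercises both GPU backends, and each reports running out of shared memory with a different message: "Shared memory size limit exceeded" from Triton, "kernel exceeds available shared memory" from Mosaic GPU. Since `skip_if_errors` accepts a sequence of error patterns, the pair could also be folded into a single alias, sketched below (not part of the commit; the alias name is invented). The patterns must be given as a tuple: `pattern_search` only auto-wraps non-tuple inputs, so passing a list would make the `in` membership test fail with a TypeError.

import functools

import jax._src.test_util as jtu

# Hypothetical combined alias covering both GPU backends.
skip_if_gpu_kernel_exceeds_shared_memory = functools.partial(
    jtu.skip_if_errors,
    error_patterns=(
        "kernel exceeds available shared memory",  # Mosaic GPU
        "Shared memory size limit exceeded",       # Triton
    ),
    reason=lambda err, dev: f"GPU kernel exceeds shared memory on {dev}",
)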