
Commit 549d2a4

Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO

1 parent 8ff5d33

4 files changed: +73 -0 lines changed

jax/_src/test_util.py

Lines changed: 59 additions & 0 deletions
@@ -468,6 +468,65 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
     return "v6 lite" in device_kind
   return expected_version in device_kind
 
+def pattern_search(patterns: str | Sequence[str], string: str):
+  if isinstance(patterns, list):
+    patterns = tuple(patterns)
+  elif not isinstance(patterns, tuple):
+    patterns = (patterns,)  # type: ignore
+
+  for pattern in patterns:
+    if re.search(pattern, string):
+      return pattern
+  return None
+
+def skip_if_device_kind_matches(device_patterns: str | Sequence[str]):
+  """A decorator for test methods to skip the test when run on specified devices."""
+  device_kind = jax.devices()[0].device_kind
+  matching_pattern = pattern_search(device_patterns, device_kind)
+
+  reason = f"Device kind matches \"{matching_pattern}\""
+  def skip(test_method):
+    return unittest.skipIf(matching_pattern is not None, reason)(test_method)
+  return skip
+
+def skip_if_errors(
+    *,
+    error_patterns: str | Sequence[str],
+    device_patterns: str | Sequence[str],
+    reason: str | Callable[[str, str], str],
+):
+  """Skip if both error message and device kind match a corresponding pattern."""
+  def skip(test_method):
+    @functools.wraps(test_method)
+    def test_method_wrapper(self, *args, **kwargs):
+      device_kind = jax.devices()[0].device_kind
+      try:
+        return test_method(self, *args, **kwargs)
+      except Exception as e:
+        matching_error_pattern = pattern_search(error_patterns, str(e))
+        matching_device_pattern = pattern_search(device_patterns, device_kind)
+        if matching_error_pattern and matching_device_pattern:
+          if not isinstance(reason, str):
+            reason_str = reason(matching_error_pattern, matching_device_pattern)
+          else:
+            reason_str = reason
+          self.skipTest(reason_str)
+        raise
+    return test_method_wrapper
+  return skip
+
+skip_if_mosaic_gpu_exceeds_shared_memory = functools.partial(
+    skip_if_errors,
+    error_patterns="kernel exceeds available shared memory",
+    reason=lambda err, dev: f"Mosaic GPU kernel exceeds shared memory on {dev}",
+)
+
+skip_if_triton_exceeds_shared_memory = functools.partial(
+    skip_if_errors,
+    error_patterns="Shared memory size limit exceeded",
+    reason=lambda err, dev: f"Triton kernel exceeds shared memory on {dev}",
+)
+
 def is_cuda_compute_capability_at_least(capability: str) -> bool:
   if not is_device_cuda():
     return False
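The general-purpose decorator is reusable beyond the two partials above: it runs the test body and converts a failure into a skip only when both the raised error message and the current device kind match one of the given patterns; any other exception propagates as a normal failure. A minimal sketch of direct usage — the test class, test body, and combined patterns here are illustrative, not part of this commit:

    from jax._src import test_util as jtu

    class ExampleTest(jtu.JaxTestCase):  # hypothetical test class

      # Skip (rather than fail) only if BOTH the error text and the device
      # kind match; any other exception still fails the test as usual.
      @jtu.skip_if_errors(
          error_patterns=[
              "kernel exceeds available shared memory",  # Mosaic GPU wording
              "Shared memory size limit exceeded",       # Triton wording
          ],
          device_patterns="RTX PRO 6000 Blackwell",
          reason=lambda err, dev: f"kernel does not fit on {dev} ({err})",
      )
      def test_big_kernel(self):
        ...  # kernel that may exceed shared memory on workstation GPUs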

tests/mosaic/gpu_test.py

Lines changed: 5 additions & 0 deletions
@@ -664,6 +664,7 @@ def kernel(ctx, inp, out, smem):
           fa.WGMMA_LAYOUT_UPCAST_4X,
       ),
   )
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_optimized_conversion(self, jax_dtype_from_to, layout):
     jax_dtype_from, jax_dtype_to = jax_dtype_from_to
     mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
@@ -2752,6 +2753,7 @@ def kernel(ctx, src, dst, _):
     np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
 
   @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_strided_load_store(self, in_shape):
     def kernel(ctx, *args):
       gmem_input, gmem_output, (smem_input, smem_output) = args
@@ -3070,6 +3072,7 @@ def kernel(ctx, dst, _):
       row_tiling=[8, 64],
   )
   @jtu.thread_unsafe_test()  # Modifies ``os.environ``.
+  @jtu.skip_if_device_kind_matches("RTX PRO 6000 Blackwell")
   def test_copy_tiled(self, dtype, swizzle, num_col_tiles, row_tiling):
     mlir_dtype = utils.dtype_to_ir_type(dtype)
     bw = bytewidth(mlir_dtype)
@@ -3154,6 +3157,7 @@ def kernel(ctx, in_, out, smems):
           (fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT),
       ],
   )
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_transpose_tiled(self, dtype, swizzle, layouts):
     mlir_dtype = utils.dtype_to_ir_type(dtype)
     bw = bytewidth(mlir_dtype)
@@ -3198,6 +3202,7 @@ def kernel(ctx, in_, out, smems):
       (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
   )
   @jtu.thread_unsafe_test()  # Modifies ``os.environ``.
+  @jtu.skip_if_device_kind_matches("RTX PRO 6000 Blackwell")
   def test_upcast_to_wgmma(
       self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
   ):

tests/pallas/mosaic_gpu_test.py

Lines changed: 6 additions & 0 deletions
@@ -226,6 +226,7 @@ def kernel(x_ref, y_ref, o_ref):
     np.testing.assert_array_equal(kernel(x, y), x + y[0])
 
   @parameterized.product(shape=[(128,), (128, 128)])
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_reduce_sum(self, shape):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32)
@@ -719,6 +720,7 @@ def kernel(x_ref, o_ref, barrier_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_scoped_copy_with_transforms(self):
     self.skip_if_wg_semantics()
 
@@ -742,6 +744,7 @@ def body(tmp_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x * 2)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_scoped_copy_with_user_transforms(self):
     def kernel(x_ref, o_ref, barrier_ref):
       def body(tmp_ref):
@@ -762,6 +765,7 @@ def body(tmp_ref):
     x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
     np.testing.assert_array_equal(f(x), x * 2)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_copy_with_transforms_and_indexing(self):
     self.skip_if_wg_semantics()
 
@@ -811,6 +815,7 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128)
     np.testing.assert_array_equal(kernel(x), x)
 
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_indexing_before_transpose(self):
     self.skip_if_wg_semantics()
 
@@ -2021,6 +2026,7 @@ def kernel(x_ref, y_ref, smem_ref, smem_out_ref, barrier_ref):
     np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(x[:, None], (128, 128)))
 
   @parameterized.named_parameters((l.name.lower(), l) for l in plgpu.Layout)
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_copy_layout(self, layout):
     self.skip_if_wg_semantics()
     if layout in {

tests/pallas/ops_test.py

Lines changed: 3 additions & 0 deletions
@@ -560,6 +560,7 @@ def kernel(*refs):
       for name, func, strategy in UNARY_FUNCTIONS
   )
   @hp.given(hps.data())
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
     if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]:
       self.skip_if_mosaic_gpu()
@@ -1897,6 +1898,8 @@ def f(x_ref, o_ref):
       trans_x=[False, True],
       trans_y=[False, True],
   )
+  @jtu.skip_if_triton_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
+  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
   def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
     self.skip_if_mosaic_gpu()
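Note that the two shared-memory skips compose: each wrapper re-raises any exception whose message does not match its own patterns, so stacking the Triton and Mosaic GPU variants on test_dot lets whichever backend is active convert its shared-memory failure into a skip while leaving unrelated errors fatal.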
