Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO #30258

Merged
52 changes: 52 additions & 0 deletions jax/_src/test_util.py
@@ -468,6 +468,58 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
return "v6 lite" in device_kind
return expected_version in device_kind

def pattern_search(patterns: str | Sequence[str], string: str):
  if not isinstance(patterns, tuple):
    patterns = (patterns,)  # type: ignore

  for pattern in patterns:
    if pattern in string:
      return pattern
  return None

def device_kind_matches(device_patterns: str | Sequence[str]):
  device_kind = xla_bridge.devices()[0].device_kind
  matching_pattern = pattern_search(device_patterns, device_kind)
  return matching_pattern is not None
Member
Just return pattern_search(...) is not None?

Contributor Author
I felt that might be too dense for one line for an unfamiliar reader; I tend to assign intermediate variables just for readability. But not a strong preference.
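
As a quick illustration of the matching semantics (these calls are examples, not part of the diff): pattern_search accepts a single string or a tuple of strings and returns the first pattern that occurs as a substring, or None.

# Illustrative only. A lone string is wrapped into a 1-tuple; the first
# substring match wins. The device-kind strings below are made-up examples.
assert pattern_search("shared memory", "kernel exceeds available shared memory") == "shared memory"
assert pattern_search(("A100", "H100"), "NVIDIA H100 80GB HBM3") == "H100"
assert pattern_search(("A100", "H100"), "Tesla T4") is None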


def skip_if_errors(
    *,
    error_patterns: str | Sequence[str],
    device_patterns: str | Sequence[str],
    reason: str | Callable[[str, str], str],
):
  """Skip if both error message and device kind match a corresponding pattern."""
  def skip(test_method):
    @functools.wraps(test_method)
    def test_method_wrapper(self, *args, **kwargs):
      device_kind = xla_bridge.devices()[0].device_kind
      try:
        return test_method(self, *args, **kwargs)
      except Exception as e:
        matching_error_pattern = pattern_search(error_patterns, str(e))
        matching_device_pattern = pattern_search(device_patterns, device_kind)
        if matching_error_pattern and matching_device_pattern:
          if not isinstance(reason, str):
            reason_str = reason(matching_error_pattern, matching_device_pattern)
          else:
            reason_str = reason
          self.skipTest(reason_str)
        raise
    return test_method_wrapper
  return skip

skip_if_mosaic_gpu_exceeds_shared_memory = functools.partial(
    skip_if_errors,
    error_patterns="kernel exceeds available shared memory",
    reason=lambda err, dev: f"Mosaic GPU kernel exceeds shared memory on {dev}",
)

skip_if_triton_exceeds_shared_memory = functools.partial(
    skip_if_errors,
    error_patterns="Shared memory size limit exceeded",
    reason=lambda err, dev: f"Triton kernel exceeds shared memory on {dev}",
)

def is_cuda_compute_capability_at_least(capability: str) -> bool:
  if not is_device_cuda():
    return False
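A minimal usage sketch, assuming a hypothetical test class (the class, test name, and kernel call below are illustrative, not from this diff): the wrapper runs the test normally and converts a failure into a skip only when both the error text and the device kind match their patterns.

# Assumes the usual test-side import: from jax._src import test_util as jtu
class ExampleTest(jtu.JaxTestCase):

  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(
      device_patterns="RTX PRO 6000 Blackwell"
  )
  def test_big_kernel(self):
    # Hypothetical kernel launch. If it raises an error containing
    # "kernel exceeds available shared memory" and the device kind matches
    # the pattern, the test is skipped; any other failure, or the same
    # failure on a different device, still propagates.
    run_big_mosaic_kernel()

Pre-binding the error pattern and reason with functools.partial means each test site only has to name the device pattern, which keeps the per-test annotations short.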
20 changes: 18 additions & 2 deletions tests/mosaic/gpu_test.py
@@ -664,6 +664,7 @@ def kernel(ctx, inp, out, smem):
          fa.WGMMA_LAYOUT_UPCAST_4X,
      ),
  )
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_optimized_conversion(self, jax_dtype_from_to, layout):
    jax_dtype_from, jax_dtype_to = jax_dtype_from_to
    mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
@@ -2752,6 +2753,7 @@ def kernel(ctx, src, dst, _):
    np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))

  @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_strided_load_store(self, in_shape):
    def kernel(ctx, *args):
      gmem_input, gmem_output, (smem_input, smem_output) = args
@@ -3118,7 +3120,13 @@ def get_reg(addr):
        return addr[:pos]
      return addr
    used_regs = {get_reg(addr) for addr in addrs}
    self.assertLessEqual(len(used_regs), expected_regs)
    try:
      self.assertLessEqual(len(used_regs), expected_regs)
    except:
      problematic_device = "RTX PRO 6000 Blackwell"
      if jtu.device_kind_matches(problematic_device):
        self.skipTest(f"{problematic_device} uses more registers for an unknown reason")
      raise

  def test_copy_for_upcast(self):
    dtype = jnp.int8
@@ -3154,6 +3162,7 @@ def kernel(ctx, in_, out, smems):
          (fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT),
      ],
  )
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_transpose_tiled(self, dtype, swizzle, layouts):
    mlir_dtype = utils.dtype_to_ir_type(dtype)
    bw = bytewidth(mlir_dtype)
@@ -3198,6 +3207,7 @@ def kernel(ctx, in_, out, smems):
      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
  )
  @jtu.thread_unsafe_test()  # Modifies ``os.environ``.
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_upcast_to_wgmma(
      self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
  ):
@@ -3245,7 +3255,13 @@ def tile(x, tiling):
    yt_kernel = f(xt)
    jax.block_until_ready(yt_kernel)
    np.testing.assert_array_equal(yt_kernel, yt)
    self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
    try:
      self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
    except:
      problematic_device = "RTX PRO 6000 Blackwell"
      if jtu.device_kind_matches(problematic_device):
        self.skipTest(f"{problematic_device} requires more SHFL.BFLY for an unknown reason")
      raise


@dataclasses.dataclass(frozen=True)
6 changes: 6 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -226,6 +226,7 @@ def kernel(x_ref, y_ref, o_ref):
    np.testing.assert_array_equal(kernel(x, y), x + y[0])

  @parameterized.product(shape=[(128,), (128, 128)])
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_reduce_sum(self, shape):
    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32)
@@ -719,6 +720,7 @@ def kernel(x_ref, o_ref, barrier_ref):
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x)

  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_scoped_copy_with_transforms(self):
    self.skip_if_wg_semantics()

@@ -742,6 +744,7 @@ def body(tmp_ref):
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x * 2)

  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_scoped_copy_with_user_transforms(self):
    def kernel(x_ref, o_ref, barrier_ref):
      def body(tmp_ref):
@@ -762,6 +765,7 @@ def body(tmp_ref):
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x * 2)

  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_copy_with_transforms_and_indexing(self):
    self.skip_if_wg_semantics()

@@ -811,6 +815,7 @@ def kernel(x_ref, o_ref):
    x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128)
    np.testing.assert_array_equal(kernel(x), x)

  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_indexing_before_transpose(self):
    self.skip_if_wg_semantics()

@@ -2021,6 +2026,7 @@ def kernel(x_ref, y_ref, smem_ref, smem_out_ref, barrier_ref):
    np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(x[:, None], (128, 128)))

  @parameterized.named_parameters((l.name.lower(), l) for l in plgpu.Layout)
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_copy_layout(self, layout):
    self.skip_if_wg_semantics()
    if layout in {
3 changes: 3 additions & 0 deletions tests/pallas/ops_test.py
@@ -560,6 +560,7 @@ def kernel(*refs):
      for name, func, strategy in UNARY_FUNCTIONS
  )
  @hp.given(hps.data())
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_unary_primitives(self, name, func, shape_dtype_strategy, data):
    if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]:
      self.skip_if_mosaic_gpu()
@@ -1897,6 +1898,8 @@ def f(x_ref, o_ref):
      trans_x=[False, True],
      trans_y=[False, True],
  )
  @jtu.skip_if_triton_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  @jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
    self.skip_if_mosaic_gpu()
