Skip to content

Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO #30258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

andportnoy
Copy link
Contributor

Depends on #30257.

@andportnoy andportnoy force-pushed the aportnoy/skip-tests-too-large-for-rtx-6000-pro branch from 1afc6a0 to 4afbf78 Compare July 16, 2025 18:31
return test_method(self, *args, **kwargs)
except Exception as e:
mosaic_gpu_pattern = "kernel exceeds available shared memory"
triton_pattern = "Shared memory size limit exceeded"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very mosaic GPU/triton specific. Maybe it would be better to have a skip_if_errors(device_patterns, messages) and specialize the message in the pallas tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I factored the generic logic out but kept the specialized versions in test_util.py to avoid copying them over to too many files. What do you think?

@andportnoy andportnoy force-pushed the aportnoy/skip-tests-too-large-for-rtx-6000-pro branch from 4afbf78 to 549d2a4 Compare July 17, 2025 17:31
@andportnoy andportnoy requested a review from justinjfu July 17, 2025 17:33
Copy link
Collaborator

@justinjfu justinjfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 17, 2025

def skip_if_device_kind_matches(device_patterns: str | Sequence[str]):
"""A decorator for test methods to skip the test when run on specified devices."""
device_kind = jax.devices()[0].device_kind
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is breaking our internal CI - can you move this line and the one below inside of the skip function below? We can't query jax.devices() at the top-level of a python file.

@andportnoy andportnoy force-pushed the aportnoy/skip-tests-too-large-for-rtx-6000-pro branch from 549d2a4 to 71f6976 Compare July 17, 2025 19:53
patterns = (patterns,) # type: ignore

for pattern in patterns:
if re.search(pattern, string):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simply pattern in string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -468,6 +468,64 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
return "v6 lite" in device_kind
return expected_version in device_kind

def pattern_search(patterns: str | Sequence[str], string: str):
if isinstance(patterns, list):
patterns = tuple(patterns)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to cast it to a tuple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

def skip_if_device_kind_matches(device_patterns: str | Sequence[str]):
"""A decorator for test methods to skip the test when run on specified devices."""
def skip(test_method):
device_kind = jax.devices()[0].device_kind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in the other case, this can't query jax.devices() before the test_method is actually invoked

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with a different approach.

@@ -3070,6 +3072,7 @@ def kernel(ctx, dst, _):
row_tiling=[8, 64],
)
@jtu.thread_unsafe_test() # Modifies ``os.environ``.
@jtu.skip_if_device_kind_matches("RTX PRO 6000 Blackwell")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add comments explaining why those tests have to be skipped on this HW? It's clear that lack of SMEM might be a problem, but why is this one (and multiple others) failing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated to skip with an explicit reason wherever it's needed.

@andportnoy andportnoy force-pushed the aportnoy/skip-tests-too-large-for-rtx-6000-pro branch from 71f6976 to 1b3c120 Compare July 22, 2025 19:40
@andportnoy andportnoy requested a review from apaszke July 22, 2025 19:42
def device_kind_matches(device_patterns: str | Sequence[str]):
device_kind = jax.devices()[0].device_kind
matching_pattern = pattern_search(device_patterns, device_kind)
return matching_pattern is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just return pattern_search(...) is not None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I felt that might be too dense for one line to an unfamiliar reader, I tend to assign intermediate variables just for readability. But not a strong preference.

@andportnoy andportnoy force-pushed the aportnoy/skip-tests-too-large-for-rtx-6000-pro branch from 1b3c120 to 0b0ac35 Compare July 23, 2025 17:39
@andportnoy andportnoy added the CI Optional GPU Presubmit Label to flag PR to run additional GPU testing not in standard presubmits label Jul 23, 2025
@andportnoy andportnoy force-pushed the aportnoy/skip-tests-too-large-for-rtx-6000-pro branch from 0b0ac35 to 240c2d0 Compare July 23, 2025 18:15
@copybara-service copybara-service bot merged commit e7a3298 into jax-ml:main Jul 24, 2025
22 of 23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CI Optional GPU Presubmit Label to flag PR to run additional GPU testing not in standard presubmits kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants