-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO #30258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO #30258
Conversation
1afc6a0
to
4afbf78
Compare
jax/_src/test_util.py
Outdated
return test_method(self, *args, **kwargs) | ||
except Exception as e: | ||
mosaic_gpu_pattern = "kernel exceeds available shared memory" | ||
triton_pattern = "Shared memory size limit exceeded" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very mosaic GPU/triton specific. Maybe it would be better to have a skip_if_errors(device_patterns, messages)
and specialize the message in the pallas tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I factored the generic logic out but kept the specialized versions in test_util.py
to avoid copying them over to too many files. What do you think?
4afbf78
to
549d2a4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
jax/_src/test_util.py
Outdated
|
||
def skip_if_device_kind_matches(device_patterns: str | Sequence[str]): | ||
"""A decorator for test methods to skip the test when run on specified devices.""" | ||
device_kind = jax.devices()[0].device_kind |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is breaking our internal CI - can you move this line and the one below inside of the skip
function below? We can't query jax.devices() at the top-level of a python file.
549d2a4
to
71f6976
Compare
jax/_src/test_util.py
Outdated
patterns = (patterns,) # type: ignore | ||
|
||
for pattern in patterns: | ||
if re.search(pattern, string): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simply pattern in string
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
jax/_src/test_util.py
Outdated
@@ -468,6 +468,64 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool: | |||
return "v6 lite" in device_kind | |||
return expected_version in device_kind | |||
|
|||
def pattern_search(patterns: str | Sequence[str], string: str): | |||
if isinstance(patterns, list): | |||
patterns = tuple(patterns) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to cast it to a tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
jax/_src/test_util.py
Outdated
def skip_if_device_kind_matches(device_patterns: str | Sequence[str]): | ||
"""A decorator for test methods to skip the test when run on specified devices.""" | ||
def skip(test_method): | ||
device_kind = jax.devices()[0].device_kind |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As in the other case, this can't query jax.devices() before the test_method is actually invoked
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replaced with a different approach.
tests/mosaic/gpu_test.py
Outdated
@@ -3070,6 +3072,7 @@ def kernel(ctx, dst, _): | |||
row_tiling=[8, 64], | |||
) | |||
@jtu.thread_unsafe_test() # Modifies ``os.environ``. | |||
@jtu.skip_if_device_kind_matches("RTX PRO 6000 Blackwell") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add comments explaining why those tests have to be skipped on this HW? It's clear that lack of SMEM might be a problem, but why is this one (and multiple others) failing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated to skip with an explicit reason wherever it's needed.
71f6976
to
1b3c120
Compare
def device_kind_matches(device_patterns: str | Sequence[str]): | ||
device_kind = jax.devices()[0].device_kind | ||
matching_pattern = pattern_search(device_patterns, device_kind) | ||
return matching_pattern is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just return pattern_search(...) is not None
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I felt that might be too dense for one line to an unfamiliar reader, I tend to assign intermediate variables just for readability. But not a strong preference.
1b3c120
to
0b0ac35
Compare
0b0ac35
to
240c2d0
Compare
Depends on #30257.