Add option for selective op AC to filter mm shapes based on fqn #1380


Open · wants to merge 1 commit into base: main

Conversation


@soulitzer soulitzer commented Jul 11, 2025

Also see discussion in #1372

This PR:

  • Adds a new config for SAC whose default makes per-op SAC automatically skip saving (i.e., force recomputation of) all mms whose args[1].shape matches that of the Linear at fqn "moe.router.gate"
  • Adds general flop / activation-memory / correctness tests for AC, as well as tests for the new config
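The shape-based filtering described above can be sketched as a small policy function. This is an illustrative sketch only, not torchtitan's actual implementation; all names (`make_sac_policy`, the decision strings, the op names) are hypothetical:

```python
# Sketch of a per-op SAC policy that force-recomputes matmuls whose
# second operand's shape matches a filtered module (e.g. the Linear of
# an MoE router gate). Names are illustrative, not torchtitan's API.

def make_sac_policy(filtered_mm_shapes, save_ops):
    """Return a function mapping (op name, operand shapes) -> decision.

    filtered_mm_shapes: set of args[1] shapes whose mms must be recomputed.
    save_ops: ops whose outputs per-op SAC would normally save.
    """
    def policy(op_name, *arg_shapes):
        if op_name == "aten.mm" and tuple(arg_shapes[1]) in filtered_mm_shapes:
            return "must_recompute"  # skip saving this mm's output
        return "must_save" if op_name in save_ops else "prefer_recompute"
    return policy
```

With a hypothetical router gate Linear of shape (256, 8), `policy("aten.mm", (1024, 256), (256, 8))` would return `"must_recompute"`, while mms with other weight shapes stay on the normal save list.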

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 11, 2025
@soulitzer soulitzer force-pushed the soulitzer/add-sac-psuedo-fqn-policy branch 4 times, most recently from 9e3b49b to 3c4d97d Compare July 11, 2025 14:38
Contributor

@tianyu-l tianyu-l left a comment


This is a great idea!
Since this feature is advanced, could you also help verify that the behavior is as expected?

It seems this feature does not require distributed execution, so maybe we can add a unit test file in
https://github.com/pytorch/torchtitan/tree/main/tests/unit_tests

But if that doesn't make sense, feel free to do it in the way you prefer.

@@ -487,6 +487,20 @@ class ActivationCheckpoint:
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""

selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
default_factory=lambda: []
Contributor

Seems good enough to just do:

Suggested change
-    default_factory=lambda: []
+    default_factory=list

or we can default to ["moe.router.gate"] so that we don't need to define it in a lot of tomls.

Otherwise, could you please also update the tomls in https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/llama4/train_configs
and
https://github.com/pytorch/torchtitan/tree/main/torchtitan/models/deepseek_v3/train_configs
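The two defaults under discussion can be sketched with plain dataclasses. This is a minimal illustration of the `default_factory` alternatives, using the shorter field name proposed later in this thread; it is not the PR's actual code:

```python
from dataclasses import dataclass, field

@dataclass
class ACConfigEmptyDefault:
    # Option A: empty default, written concisely as suggested above.
    per_op_sac_filter_fqns: list[str] = field(default_factory=list)

@dataclass
class ACConfigRouterDefault:
    # Option B: default to the MoE router gate fqn, so most tomls
    # would not need to set this field at all.
    per_op_sac_filter_fqns: list[str] = field(
        default_factory=lambda: ["moe.router.gate"]
    )
```

Either way, `field(default_factory=...)` is required because a bare mutable default (`= []`) would be shared across all instances.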

Comment on lines 267 to 273
if (
fqn
not in ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns
):
Contributor

Note that in float8 we also filter by fqns, but there we do it the other way around:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/utils.py#L25
I think one reason could be that the filter over there is applied to the whole model, so one fqn can map to multiple layers / modules.

I think for AC there's not that much difference between the two. The benefit of doing it the other way is that users don't need to specify the exact full relative fqn within the AC region; e.g., "router.gate" would also work.

I don't have a strong preference, but maybe let's be consistent with float8 if you don't have a strong preference either.
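The two matching styles being compared can be sketched as follows. The function names are illustrative; the PR's snippet uses exact membership, while the float8 helper linked above matches fqns by substring:

```python
def exact_match(fqn: str, filter_fqns: list[str]) -> bool:
    # Style in this PR's snippet: the fqn must appear verbatim in the list.
    return fqn in filter_fqns

def substring_match(fqn: str, filter_fqns: list[str]) -> bool:
    # float8-style: any filter entry contained in the fqn matches, so a
    # short suffix like "router.gate" also catches nested module paths.
    return any(f in fqn for f in filter_fqns)
```

For example, `substring_match("layers.0.moe.router.gate", ["router.gate"])` is true, whereas `exact_match` would require the full relative fqn.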

@@ -487,6 +487,20 @@ class ActivationCheckpoint:
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
"""

selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
Contributor

I wonder if we should prefer a shorter name over one that fully spells out its meaning. How about per_op_sac_filter_fqns? Most users shouldn't really care about the implementation details; those who do can check the help message and the implementation.

Author

Good question. To reduce the cognitive load of parsing the config file for the average user (who, I agree, won't care about the implementation), would it help if the default already included the MoE router fqns for TorchTitan models, per your other suggestion? Then most configs wouldn't need to contain the field at all, so most users would never see it, and the advanced users who do use it would still benefit from a more explicit name.

I think this is consistent with most users not being aware of what the per-op SAC policy is at all, although we could potentially refactor things so that we expose specific named policies, like policy="compute_intensive_excluding_every_other_matmul".
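The named-policy idea floated above could look roughly like the sketch below. Everything here is hypothetical (the registry, the op set, and the save/recompute convention); only the policy name comes from the comment:

```python
# Hypothetical registry mapping policy names to per-op save decisions.
COMPUTE_INTENSIVE_OPS = {
    "aten.mm",
    "aten._scaled_dot_product_flash_attention",
}

def compute_intensive_excluding_every_other_matmul(op_name, mm_index):
    """Save compute-intensive ops, but recompute every other matmul."""
    if op_name not in COMPUTE_INTENSIVE_OPS:
        return False  # non-compute-intensive ops are always recomputed
    if op_name == "aten.mm" and mm_index % 2 == 1:
        return False  # recompute odd-indexed matmuls
    return True

POLICIES = {
    "compute_intensive_excluding_every_other_matmul":
        compute_intensive_excluding_every_other_matmul,
}
```

A config would then only need `policy = "compute_intensive_excluding_every_other_matmul"` rather than the underlying fqn/shape details.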
