Add option for selective op AC to filter mm shapes based on fqn #1380
base: main
Conversation
Force-pushed from 9e3b49b to 3c4d97d
This is a great idea!
Since this feature is advanced, could you also help test whether the behavior is as expected?
It seems this feature does not require distributed, so maybe we can add a unit test file in
https://github.com/pytorch/torchtitan/tree/main/tests/unit_tests
But if that doesn't make sense, feel free to do it the way you prefer.
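For illustration, a minimal sketch of what such a unit test could look like. It exercises PyTorch's selective-checkpoint API directly rather than this PR's helpers, so ToyRouter, the policy function, and the test name are made up; the real test should go through the AC wrapper added in torchtitan.

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


class ToyRouter(nn.Module):
    """Stand-in for moe.router.gate: a single small mm we want force-recomputed."""

    def __init__(self) -> None:
        super().__init__()
        self.gate = nn.Linear(16, 4, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.mm(x, self.gate.weight.t())


def test_filtered_mm_is_recomputed() -> None:
    model = ToyRouter()
    # Shapes of mm weights belonging to the filtered fqns ("gate" here).
    filtered_shapes = {tuple(model.gate.weight.t().shape)}

    def policy_fn(ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            # Recompute mms for filtered fqns, save all other mms.
            if tuple(args[1].shape) in filtered_shapes:
                return CheckpointPolicy.PREFER_RECOMPUTE
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    x = torch.randn(8, 16, requires_grad=True)
    out = checkpoint(
        model,
        x,
        use_reentrant=False,
        context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
    )
    out.sum().backward()

    assert x.grad is not None
    # The checkpointed forward must match a plain forward numerically.
    torch.testing.assert_close(out, model(x))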
torchtitan/config_manager.py (Outdated)
@@ -487,6 +487,20 @@ class ActivationCheckpoint:
    'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
    """

    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
        default_factory=lambda: []
seems good enough to

-    default_factory=lambda: []
+    default_factory=list
or we can default to ["moe.router.gate"]
so that we don't need to define it in a lot of tomls.
Otherwise, could you please also update the tomls in https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/llama4/train_configs and https://github.com/pytorch/torchtitan/tree/main/torchtitan/models/deepseek_v3/train_configs
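For illustration, here is what the field could look like with both suggestions applied (a sketch, not the PR's current code). Note that the default_factory=list simplification only covers the empty default; the non-empty ["moe.router.gate"] default still needs a lambda, since dataclasses forbid mutable defaults.

from dataclasses import dataclass, field


@dataclass
class ActivationCheckpoint:
    # ... existing fields ...

    # Empty default: the plain `list` constructor is enough.
    # selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
    #     default_factory=list
    # )

    # Non-empty default (reviewer suggestion): a lambda is still required.
    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
        default_factory=lambda: ["moe.router.gate"]
    )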
if (
    fqn
    not in ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns
):
Note that in float8 we also filter by fqns, but there the check goes the other way around:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/utils.py#L25
I think one reason could be that the filter over there is applied to the whole model, so one fqn can map to multiple layers / modules.
I think for AC there's not that much difference between the two. The benefit of doing it the other way may be that users don't need to spell out the full relative fqn within the AC region exactly; e.g. "router.gate" would also work.
I don't have a strong preference, but maybe let's be consistent with float8 if you don't have a strong preference either.
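To make the two options concrete, a small sketch (not code from this PR or from the float8 utils; the function names are made up): the PR currently does an exact membership test, while the float8-style filter is a substring match, so a shorter entry like "router.gate" would also hit "moe.router.gate".

# Exact match, as in the current diff: the configured entry must equal the
# module fqn seen inside the AC region.
def is_filtered_exact(fqn: str, filter_fqns: list[str]) -> bool:
    return fqn in filter_fqns


# float8-style substring match (cf. components/quantization/utils.py): any
# configured entry contained in the fqn counts, so "router.gate" also
# matches "moe.router.gate".
def is_filtered_substring(fqn: str, filter_fqns: list[str]) -> bool:
    return any(filter_fqn in fqn for filter_fqn in filter_fqns)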
torchtitan/config_manager.py (Outdated)
@@ -487,6 +487,20 @@ class ActivationCheckpoint:
    'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
    """

    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
I wonder if we should prefer a shorter name over how accurately it conveys its meaning. How about per_op_sac_filter_fqns? Most users shouldn't really care about the implementation details; if some users do, they can check the help message and the implementation.
Good question. To help reduce the cognitive load of parsing the config file for the average user, who I agree won't care about the impl, would it help if the default already included the MoE router fqns for TorchTitan models, per your other suggestion? That way most configs won't need to contain it at all, so most users won't see it, and the advanced users who do use it will still benefit from a more explicit name.
I think this is consistent with most users already not being aware of what the per-op SAC policy is at all, although we could potentially refactor things so that we have specific policies like policy="compute_intensive_excluding_every_other_matmul"
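Whatever the field ends up being named, a rough sketch of the mechanism under discussion (all names here, including per_op_sac_filter_fqns and the helper, are illustrative rather than the PR's actual code): collect the mm weight shapes of modules whose fqn matches a filter entry, then have the per-op SAC policy recompute mms with those shapes instead of saving them.

import torch.nn as nn


def collect_filtered_mm_shapes(
    module: nn.Module, per_op_sac_filter_fqns: list[str]
) -> set[tuple[int, ...]]:
    """Return mm weight shapes of Linear submodules whose fqn matches a filter entry."""
    shapes: set[tuple[int, ...]] = set()
    for fqn, submod in module.named_modules():
        if not any(f in fqn for f in per_op_sac_filter_fqns):
            continue
        if isinstance(submod, nn.Linear):
            # nn.Linear computes x @ weight.T, so the mm sees the transposed shape.
            shapes.add(tuple(submod.weight.t().shape))
    return shapes

The per-op SAC policy would then return a recompute decision (rather than save) for aten.mm calls whose second argument has one of these shapes.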
Force-pushed from 3c4d97d to bfb3a32
Also see discussion in #1372
This PR: