Add option to exclude low flop mms from every-other-mm sac policy #1372
Conversation
@danielvegamyhre @H-Huang Also, one scenario I could imagine is playing with different input shapes / model sizes during debugging, where this option would make the behavior "silently" change. Maybe it's better to have a mechanism to filter by FQN, like what we did for float8 training: https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/float8.py#L56
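For reference, a minimal sketch of the kind of FQN-based filter being suggested, modeled loosely on the float8 filter linked above; the names `filter_fqns` and `module_filter_fn` here are illustrative, not existing torchtitan AC config fields:

# Hypothetical sketch of an FQN-based exclusion filter, in the spirit of the
# float8 filter referenced above. `filter_fqns` is an assumed config field.
def module_filter_fn(fqn: str, filter_fqns: list[str]) -> bool:
    """Return True if the module at `fqn` should be excluded from the
    mm-saving policy, e.g. because it matches a substring like "router.gate"."""
    return any(f in fqn for f in filter_fqns)

# Example usage: exclude the MoE router gate, keep attention projections.
assert module_filter_fn("layers.3.moe.router.gate", ["router.gate"])
assert not module_filter_fn("layers.3.attention.wq", ["router.gate"])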
For testing, could you fix the seed and verify the loss is the same with / without the threshold taking effect? Verification in eager mode only is fine.
btw ghstack may not work in torchtitan, sadly
assert k == k2
flops = m * n * 2 * k
if flops < ac_config.selective_op_ac_mm_flops_threshold:
    return CheckpointPolicy.PREFER_RECOMPUTE
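For context, the hunk above sits inside the SAC op policy, a callable handed to torch.utils.checkpoint's selective-checkpoint machinery that returns a CheckpointPolicy per op. Below is a simplified sketch of an every-other-mm policy with the proposed flop threshold; it is not the exact torchtitan implementation, and whether small matmuls also skip the every-other-mm counter is an assumption of this sketch:

import torch
from torch.utils.checkpoint import CheckpointPolicy

def make_mm_policy(ac_config):
    # Mutable counter shared across calls, so we can save every other matmul.
    mm_count = {"n": 0}

    def policy_fn(ctx, op, *args, **kwargs):
        if op != torch.ops.aten.mm.default:
            return CheckpointPolicy.PREFER_RECOMPUTE

        # aten.mm multiplies an [m, k] tensor by a [k2, n] tensor.
        m, k = args[0].shape
        k2, n = args[1].shape
        assert k == k2
        flops = m * n * 2 * k

        # The new option: matmuls below the threshold are always recomputed
        # (and, in this sketch, do not advance the every-other-mm counter).
        if flops < ac_config.selective_op_ac_mm_flops_threshold:
            return CheckpointPolicy.PREFER_RECOMPUTE

        mm_count["n"] += 1
        # Save every other large matmul, recompute the rest.
        if mm_count["n"] % 2 == 0:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy_fn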
Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?
Hmm, I'm guessing it's because the matmul FLOPs required grow cubically, O(n^3), while the output activation grows only quadratically, O(n^2). When n is small, the FLOPs required to recompute are relatively small compared to the size of the output activation. In contrast, when n is larger, the FLOPs to recompute have grown cubically but the memory saved has only grown quadratically, so the trade-off for saving instead of recomputing becomes more favorable.
If so, I think this change is a good idea, as long as it is configurable and has a sensible default that is documented clearly.
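A quick back-of-the-envelope check of that scaling argument, with purely illustrative square shapes:

# Recompute cost (FLOPs) vs. memory saved (output elements) for an
# [n, n] @ [n, n] matmul; the sizes below are arbitrary examples.
def mm_cost(n: int):
    flops = 2 * n * n * n   # grows cubically
    out_elems = n * n       # grows quadratically
    return flops, out_elems

for n in (128, 8192):
    flops, out_elems = mm_cost(n)
    print(f"n={n}: recompute FLOPs per saved output element = {flops // out_elems}")
# n=128  -> 256   FLOPs per element saved: cheap to recompute
# n=8192 -> 16384 FLOPs per element saved: saving pays off more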
Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?
Good question. Relative to non-matmuls, small matmuls perhaps still do a disproportionate amount of compute relative to output size, so there could definitely be benefit to saving as well.
Although, as Daniel mentions, relative to large matmuls, small matmuls indeed have a more favorable memory/compute trade-off for recomputing, so intuitively it might be more Pareto optimal to recompute, e.g.
If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul, rather than spending that same amount of memory saving a bunch of smaller ones.
intuitively it might be more Pareto optimal to recompute, e.g.
If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul
Oh, this is a great argument, I got the idea.
@tianyu-l Updated the results in the description to use the same seed!
Could you provide a "sensible default"?
E.g. one that doesn't filter out anything in llama 3 8b/70b/405b, and only filters out the router gate in llama4 17x16e / 17x128e?
I think it's a bit tricky as it depends on the input shape [batch size x sequence length].
cc @danielvegamyhre any suggestions?
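To make the dependence on batch size x sequence length concrete, here is a rough sketch; H, E, and the token counts below are assumed values, not the real llama3/llama4 dimensions. The window of thresholds that recomputes only the router gate, 2*tokens*H*E < threshold <= 2*tokens*H*H, shifts with the token count, which is why a single default is tricky:

# Illustrative only: H, E and the token counts are assumptions, not the
# actual llama3/llama4 model dimensions.
H = 4096   # hidden size (assumed)
E = 16     # number of routed experts (assumed)

def mm_flops(m: int, k: int, n: int) -> float:
    return 2.0 * m * k * n

for tokens in (8192, 32768):               # batch_size * seq_len
    gate = mm_flops(tokens, H, E)          # router gate: [tokens, H] @ [H, E]
    dense = mm_flops(tokens, H, H)         # dense proj:  [tokens, H] @ [H, H]
    # Thresholds in (gate, dense] recompute only the gate, but that window
    # moves as the token count changes.
    print(f"tokens={tokens}: gate FLOPs={gate:.2e}, dense FLOPs={dense:.2e}")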
mode = "selective" # ["none", "selective", "full"] | ||
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy | ||
selective_op_ac_mm_flops_threshold = 0 # checking if everything is recomputed | ||
log_mm_flops = true |
Whoops, didn't mean to include this
Stack from ghstack (oldest at bottom):
See the issue described in #1182 (comment) / #1330 (comment)
Manual testing (on llama3 8b):
Without PR:
[rank0]:[titan] 2025-07-09 09:12:06,021 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 69.06GiB(87.27%) tps: 401 tflops: 23.22 mfu: 7.44%
With selective_op_ac_mm_flops_threshold=1e50:
[rank0]:[titan] 2025-07-09 09:09:56,360 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 58.42GiB(73.82%) tps: 389 tflops: 22.52 mfu: 7.22%
With selective_op_ac_mm_flops_threshold=0:
[rank0]:[titan] 2025-07-09 09:10:50,693 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 69.06GiB(87.27%) tps: 399 tflops: 23.10 mfu: 7.40%
Manually excluding all mms from being saved:
[rank0]:[titan] 2025-07-09 09:12:56,345 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 58.42GiB(73.82%) tps: 394 tflops: 22.84 mfu: 7.32%
Also adds the ability to log the mm shapes/flops:
llama 3 8b
llama 4 debug