Add option to exclude low flop mms from every-other-mm sac policy #1372

Open

soulitzer wants to merge 4 commits into base: gh/soulitzer/2/base
Conversation

@soulitzer commented Jul 8, 2025

Stack from ghstack (oldest at bottom):

See issue described in #1182 (comment) / #1330 (comment)

Manual testing (on llama3 8b):

Without PR:
[rank0]:[titan] 2025-07-09 09:12:06,021 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 69.06GiB(87.27%) tps: 401 tflops: 23.22 mfu: 7.44%

With selective_op_ac_mm_flops_threshold=1e50:
[rank0]:[titan] 2025-07-09 09:09:56,360 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 58.42GiB(73.82%) tps: 389 tflops: 22.52 mfu: 7.22%

With selective_op_ac_mm_flops_threshold=0:
[rank0]:[titan] 2025-07-09 09:10:50,693 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 69.06GiB(87.27%) tps: 399 tflops: 23.10 mfu: 7.40%

Manually excluding all mms from being saved:
[rank0]:[titan] 2025-07-09 09:12:56,345 - root - INFO - step: 1 loss: 12.2236 grad_norm: 4.1079 memory: 58.42GiB(73.82%) tps: 394 tflops: 22.84 mfu: 7.32%

Also adds the ability to log the mm shapes/flops:
llama 3 8b

[rank0]:| MM Shape                    | FLOPs    | Count |
[rank0]:| --------------------------- | -------- | ----- |
[rank0]:| (8192x4096) x (4096x4096)   | 2.75E+11 | 2     |
[rank0]:| (8192x4096) x (4096x1024)   | 6.87E+10 | 2     |
[rank0]:| (8192x4096) x (4096x14336)  | 9.62E+11 | 2     |
[rank0]:| (8192x14336) x (14336x4096) | 9.62E+11 | 1     |

llama 4 debug

[rank0]:[titan] 2025-07-10 08:35:02,476 - root - INFO -
[rank0]:| MM Shape                | FLOPs   | Count |
[rank0]:| ----------------------- | ------- | ----- |
[rank0]:| (16384x256) x (256x256) | 2.15E+9 | 4     |
[rank0]:| (16384x256) x (256x8)   | 6.71E+7 | 1     |
[rank0]:| (16x256) x (256x512)    | 4.19E+6 | 10    |
[rank0]:| (16x512) x (512x256)    | 4.19E+6 | 5     |
[rank0]:| (15104x256) x (256x512) | 3.96E+9 | 2     |
[rank0]:| (15104x512) x (512x256) | 3.96E+9 | 1     |
[rank0]:| (352x256) x (256x512)   | 9.23E+7 | 2     |
[rank0]:| (352x512) x (512x256)   | 9.23E+7 | 1     |
[rank0]:| (928x256) x (256x512)   | 2.43E+8 | 2     |
[rank0]:| (928x512) x (512x256)   | 2.43E+8 | 1     |
[rank0]:[titan] 2025-07-10 08:35:02,606 - root - INFO -
[rank0]:| MM Shape                | FLOPs   | Count |
[rank0]:| ----------------------- | ------- | ----- |
[rank0]:| (16384x256) x (256x256) | 2.15E+9 | 4     |
[rank0]:| (16384x256) x (256x8)   | 6.71E+7 | 1     |
[rank0]:| (5696x256) x (256x512)  | 1.49E+9 | 2     |
[rank0]:| (5696x512) x (512x256)  | 1.49E+9 | 1     |
[rank0]:| (192x256) x (256x512)   | 5.03E+7 | 2     |
[rank0]:| (192x512) x (512x256)   | 5.03E+7 | 1     |
[rank0]:| (16x256) x (256x512)    | 4.19E+6 | 6     |
[rank0]:| (16x512) x (512x256)    | 4.19E+6 | 3     |
[rank0]:| (5504x256) x (256x512)  | 1.44E+9 | 2     |
[rank0]:| (5504x512) x (512x256)  | 1.44E+9 | 1     |
[rank0]:| (4976x256) x (256x512)  | 1.30E+9 | 2     |
[rank0]:| (4976x512) x (512x256)  | 1.30E+9 | 1     |
[rank0]:| (32x256) x (256x512)    | 8.39E+6 | 2     |
[rank0]:| (32x512) x (512x256)    | 8.39E+6 | 1     |
[rank0]:[titan] 2025-07-10 08:35:02,736 - root - INFO -
[rank0]:| MM Shape                | FLOPs   | Count |
[rank0]:| ----------------------- | ------- | ----- |
[rank0]:| (16384x256) x (256x256) | 2.15E+9 | 4     |
[rank0]:| (16384x256) x (256x8)   | 6.71E+7 | 1     |
[rank0]:| (1392x256) x (256x512)  | 3.65E+8 | 2     |
[rank0]:| (1392x512) x (512x256)  | 3.65E+8 | 1     |
[rank0]:| (1008x256) x (256x512)  | 2.64E+8 | 2     |
[rank0]:| (1008x512) x (512x256)  | 2.64E+8 | 1     |
[rank0]:| (2960x256) x (256x512)  | 7.76E+8 | 2     |
[rank0]:| (2960x512) x (512x256)  | 7.76E+8 | 1     |
[rank0]:| (400x256) x (256x512)   | 1.05E+8 | 2     |
[rank0]:| (400x512) x (512x256)   | 1.05E+8 | 1     |
[rank0]:| (2336x256) x (256x512)  | 6.12E+8 | 2     |
[rank0]:| (2336x512) x (512x256)  | 6.12E+8 | 1     |
[rank0]:| (4144x256) x (256x512)  | 1.09E+9 | 2     |
[rank0]:| (4144x512) x (512x256)  | 1.09E+9 | 1     |
[rank0]:| (1024x256) x (256x512)  | 2.68E+8 | 2     |
[rank0]:| (1024x512) x (512x256)  | 2.68E+8 | 1     |
[rank0]:| (3168x256) x (256x512)  | 8.30E+8 | 2     |
[rank0]:| (3168x512) x (512x256)  | 8.30E+8 | 1     |
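
For reference, a minimal sketch of how the every-other-mm SAC policy could incorporate such a flops threshold (not the exact diff in this PR; it assumes the CheckpointPolicy / create_selective_checkpoint_contexts API from torch.utils.checkpoint and reuses the selective_op_ac_mm_flops_threshold name from the testing above):

from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def _get_custom_policy(meta, mm_flops_threshold):
    def _custom_policy(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default:
            m, k = args[0].shape
            k2, n = args[1].shape
            assert k == k2
            flops = m * n * 2 * k  # an (m x k) @ (k x n) matmul does 2*m*k*n FLOPs
            if flops < mm_flops_threshold:
                # Low-flop mm: cheap to recompute relative to the memory its
                # output would occupy, so never save it.
                return CheckpointPolicy.PREFER_RECOMPUTE
            # Count forward and recompute passes separately so the
            # every-other-mm decision replays identically during recompute.
            mode = "recompute" if ctx.is_recompute else "forward"
            meta[mode] += 1
            if meta[mode] % 2 == 0:
                return CheckpointPolicy.MUST_SAVE
        # Simplification: everything else (including ops torchtitan normally
        # saves) defaults to recompute in this sketch.
        return CheckpointPolicy.PREFER_RECOMPUTE
    return _custom_policy

def selective_checkpointing_context_fn(mm_flops_threshold=0.0):
    meta = defaultdict(int)
    return create_selective_checkpoint_contexts(
        _get_custom_policy(meta, mm_flops_threshold)
    )

The returned contexts would be passed to torch.utils.checkpoint.checkpoint via its context_fn= argument, e.g. functools.partial(selective_checkpointing_context_fn, threshold).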

@tianyu-l (Contributor) commented Jul 9, 2025

@danielvegamyhre @H-Huang
Could you help review this PR? I'm not sure whether it's better to hardcode a fixed threshold or to let the user choose.

Also, one scenario I could imagine is playing with different input shapes / model sizes during debugging, where this option would make the behavior change "silently". Maybe it's better to have a mechanism to filter by fqn, like what we did for float8 training: https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/float8.py#L56
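
For illustration, a hypothetical sketch of what such an fqn filter could look like (the helper name and filter_fqns parameter are assumptions mirroring the float8 filter_fqns style, not an existing torchtitan API):

def _mm_exclusion_applies(module_fqn, filter_fqns):
    # Only apply the low-flop-mm exclusion to modules whose fully-qualified
    # name contains one of the user-provided fragments, e.g. a fragment of
    # the router gate's module name.
    return any(fragment in module_fqn for fragment in filter_fqns)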

@tianyu-l (Contributor) left a comment

For testing, could you fix the seed and verify that the loss is the same with / without the threshold taking effect? Verification in eager mode only is fine.

btw ghstack may not work in torchtitan, sadly

assert k == k2
flops = m * n * 2 * k  # an (m x k) @ (k x n) matmul does 2*m*k*n FLOPs
if flops < ac_config.selective_op_ac_mm_flops_threshold:
    return CheckpointPolicy.PREFER_RECOMPUTE
Contributor:

Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Contributor:

Hmm, I'm guessing it's because the matmul FLOPs required grow cubically / O(n^3), while the output activation grows only quadratically / O(n^2), so when n is small the FLOPs required to recompute are relatively small compared to the size of the output activation. In contrast, when n is larger, the FLOPs to recompute have grown cubically but the memory saved has only grown quadratically, so the trade-off for saving instead of recomputing becomes more favorable.

If so, I think this change is a good idea, as long as it is configurable and has a sensible default that is documented clearly.
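
To put rough numbers on that intuition (an illustrative calculation assuming square n x n matmuls and bf16 activations, not measurements from this PR): recomputing costs about 2*n^3 FLOPs while the saved output occupies about 2*n^2 bytes, so the recompute FLOPs per byte of memory saved grow linearly with n:

for n in (256, 4096, 14336):
    recompute_flops = 2 * n ** 3   # FLOPs to redo an (n x n) @ (n x n) matmul
    saved_bytes = 2 * n * n        # bf16 output activation it would replace
    print(n, recompute_flops / saved_bytes)  # ratio == n: 256, 4096, 14336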

@soulitzer (Author) commented Jul 9, 2025

Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Good question. Relative to non-matmuls, small matmuls may still do a disproportionate amount of compute relative to their output size, so there could definitely be a benefit to saving them as well.

Although, as Daniel mentions, relative to large matmuls, small matmuls do have a more favorable memory/compute trade-off for recomputing, so intuitively it might be more Pareto-optimal to recompute them, e.g.

If we were to follow a greedy heuristic of only saving the op with the next best trade-off, then since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul, rather than spending that same amount of memory saving a bunch of smaller ones.

Contributor:

intuitively it might be more pareto optimal to recompute, e.g.
If we were to follow a greedy heuristic of only saving the op with the next best trade-off, since we're currently only saving every other large matmul, the most profitable next tensor to save should be another large matmul

oh this is a great argument, I got the idea.

@soulitzer (Author):

For testing, could you fix the seed and verify the loss are the same with / without the threshold taking effect? Verification only in eager mode is fine.

@tianyu-l Updated the results in the description to use the same seed!

@tianyu-l (Contributor) left a comment

Could you provide a "sensible default"?

E.g. it doesn't filter out anything in llama 3 8b/70b/405b, and only filters out the router gate in llama4 17x16e / 17x128e?

I think it's a bit tricky as it depends on the input shape [batch size x sequence length].

cc @danielvegamyhre any suggestions?
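
As a concrete illustration using the llama4-debug shapes logged in the description (hypothetical arithmetic for picking a threshold, not a proposed default):

# FLOPs = 2 * m * k * n for an (m x k) @ (k x n) matmul, with m = batch * seqlen
router_gate    = 2 * 16384 * 256 * 8    # ~6.71e7 FLOPs
smallest_other = 2 * 16384 * 256 * 256  # ~2.15e9 FLOPs
# Any threshold in (6.71e7, 2.15e9) excludes only the router gate at this m;
# halve the tokens and both values halve, so the usable window shifts with
# the input shape, which is exactly the trickiness noted above.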


mode = "selective" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
selective_op_ac_mm_flops_threshold = 0 # checking if everything is recomputed
log_mm_flops = true
@soulitzer (Author):

Whoops, didn't mean to include this
