Add option to exclude low flop mms from every-other-mm sac policy #1372
base: gh/soulitzer/2/base
@@ -8,6 +8,8 @@
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict
from decimal import Decimal
import functools

import torch
import torch.nn as nn
@@ -237,6 +239,28 @@ def apply_tp(
}


def _format_mm_flops_table(entries):
    header = ("MM Shape", "FLOPs", "Count")
    rows = [header] + [(k[0], k[1], str(v)) for k, v in entries.items()]
    col0 = max(len(r[0]) for r in rows)
    col1 = max(len(r[1]) for r in rows)
    col2 = max(len(r[2]) for r in rows)
    lines = [
        f"| {'MM Shape'.ljust(col0)} | {'FLOPs'.ljust(col1)} | {'Count'.ljust(col2)} |",
        f"| {'-' * col0} | {'-' * col1} | {'-' * col2} |",
    ]
    for s, fl, cnt in rows[1:]:
        lines.append(f"| {s.ljust(col0)} | {fl.ljust(col1)} | {cnt.ljust(col2)} |")
    return "\n".join(lines)


def _wrap_with_disable_early_stop(fn):
    def inner(*args, **kwargs):
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            return fn(*args, **kwargs)

    return inner


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
@@ -264,12 +288,38 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key_filtered = f"{mode}_mm_count_filtered"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1

                    m, k = args[0].shape
                    k2, n = args[1].shape
                    assert k == k2
                    flops = m * n * 2 * k

                    if ac_config.log_mm_flops and ctx.is_recompute:
                        shape_str = f"({m}x{k}) x ({k2}x{n})"
                        flops_str = f"{Decimal(flops):.2E}"
                        key = (shape_str, flops_str)
                        meta[key] += 1
                        if meta["recompute_mm_count"] == meta["forward_mm_count"]:
                            table = _format_mm_flops_table(
                                {k: v for k, v in meta.items() if "mm_count" not in k}
                            )
                            logger.info("\n%s", table)

                    # Filter out ops below a certain flop threshold. See the discussion for
                    # why we recompute instead of save here:
                    # https://github.com/pytorch/torchtitan/pull/1372#discussion_r2193722200
                    if flops < ac_config.selective_op_ac_mm_flops_threshold:
                        return CheckpointPolicy.PREFER_RECOMPUTE
Review comment: Is there a strong reason we use "always recompute" instead of "always save" for these small matmuls?

Reply: Hmm, I'm guessing because the matmul FLOPs required grow cubically / O(n^3), but the output activation grows only quadratically / O(n^2), so a small matmul is cheap to recompute relative to the memory its output would occupy? If so, I think this change is a good idea, as long as it is configurable and has a sensible default that is documented clearly.
Reply: Good question. Relative to non-matmuls, perhaps small matmuls still do a disproportionate amount of compute relative to output size, so there could definitely be a benefit to saving as well. Although, as Daniel mentions, relative to large matmuls, small matmuls do have a more favorable memory/compute trade-off for recomputing, so intuitively it might be more Pareto-optimal to recompute. E.g., if we were to follow a greedy heuristic of only saving the op with the next-best trade-off, then since we're currently only saving every other large matmul, the most profitable next tensor to save would be another large matmul, rather than spending that same amount of memory saving a bunch of smaller ones.
Reply: Oh, this is a great argument, I got the idea.
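
To make the greedy argument concrete, here is a quick back-of-the-envelope sketch (illustrative only, not part of the PR; the bf16 element size and the shapes are assumptions). It compares the FLOPs you would have to re-spend in backward per byte of activation memory saved by stashing a matmul's output:

```python
# Illustrative only: recompute cost per byte of activation saved for an (m x k) @ (k x n) mm.
def recompute_flops_per_byte_saved(m, k, n, bytes_per_elem=2):  # bf16 = 2 bytes/elem
    flops = 2 * m * k * n                  # cost to recompute the matmul
    saved_bytes = m * n * bytes_per_elem   # output activation we'd otherwise keep alive
    return flops / saved_bytes             # simplifies to 2 * k / bytes_per_elem

print(recompute_flops_per_byte_saved(4096, 4096, 4096))  # large mm:   4096.0 FLOPs per byte saved
print(recompute_flops_per_byte_saved(4096, 64, 4096))    # small-k mm:   64.0 FLOPs per byte saved
```

The ratio collapses to the inner dimension, so the larger the matmul's k, the more recompute FLOPs you buy back per byte spent saving its output; under a greedy "save the op with the best trade-off next" heuristic, that memory is better spent on another large matmul than on many small ones.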
                    meta[mm_count_key_filtered] += 1

                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
-                   func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                   func == torch.ops.aten.mm.default and meta[mm_count_key_filtered] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
@@ -283,9 +333,17 @@ def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
        if ac_config.log_mm_flops:
            # If early-stop is enabled, fewer mms are recomputed than in forward. Disabling
            # it slightly alters perf, but lets us deterministically know when to log the
            # mm flops table rather than having to spam it for every mm call.
            checkpoint_fn = _wrap_with_disable_early_stop(checkpoint_fn)

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            checkpoint_fn=checkpoint_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
@@ -305,7 +363,9 @@ def apply_ac(model: nn.Module, ac_config):
        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
        model.layers.register_module(layer_id, transformer_block)

-   logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+   logger.info(f"Applied {ac_config.mode} checkpointing to the model")
    if ac_config.selective_ac_option == "op" and ac_config.log_mm_flops:
        logger.info("Logging enabled for mm flops.")


def apply_compile(model: nn.Module):
Review comment: Whoops, didn't mean to include this.
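
For anyone reading along, here is a minimal standalone sketch of the same idea outside torchtitan, using the public `torch.utils.checkpoint` selective activation checkpointing API. The shapes, the `1e8` threshold, and the helper names are made up for illustration; only the policy structure mirrors the PR (low-FLOP mms are always recomputed and excluded from the every-other-mm counter, which is keyed by forward vs. recompute as in the PR):

```python
from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

MM_FLOPS_THRESHOLD = 1e8  # made-up threshold for this example


def _make_policy(meta):
    def policy(ctx, func, *args, **kwargs):
        # Counters are keyed by pass, since the policy runs both in forward and in recompute.
        mode = "recompute" if ctx.is_recompute else "forward"
        if func == torch.ops.aten.mm.default:
            m, k = args[0].shape
            _, n = args[1].shape
            if 2 * m * k * n < MM_FLOPS_THRESHOLD:
                # Low-flop mm: always recompute, and keep it out of the
                # every-other-mm counter so it can't use up a "save" slot.
                return CheckpointPolicy.PREFER_RECOMPUTE
            meta[f"{mode}_mm_count"] += 1
            # Save every other large mm.
            if meta[f"{mode}_mm_count"] % 2 == 0:
                return CheckpointPolicy.PREFER_RECOMPUTE
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy


def _context_fn():
    return create_selective_checkpoint_contexts(_make_policy(defaultdict(int)))


def block(x, w_small, w_big1, w_big2):
    h = torch.mm(x, w_small)      # low-flop mm  -> recomputed, not counted
    h = torch.mm(h, w_small.t())  # low-flop mm  -> recomputed, not counted
    h = torch.mm(h, w_big1)       # 1st large mm -> saved
    return torch.mm(h, w_big2)    # 2nd large mm -> policy prefers recompute


x = torch.randn(128, 1024)
w_small = torch.randn(1024, 32, requires_grad=True)
w_big1 = torch.randn(1024, 1024, requires_grad=True)
w_big2 = torch.randn(1024, 1024, requires_grad=True)

out = checkpoint(block, x, w_small, w_big1, w_big2, use_reentrant=False, context_fn=_context_fn)
out.sum().backward()
```

The sketch keeps the PR's key design point: filtering by FLOPs before incrementing the counter, so small matmuls neither consume memory nor shift which large matmuls land on the "save" parity.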