 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
 from collections import defaultdict
+from decimal import Decimal
+import functools
 
 import torch
 import torch.nn as nn
@@ -237,6 +239,28 @@ def apply_tp(
 }
 
 
+def _format_mm_flops_table(entries):
+    header = ("MM Shape", "FLOPs", "Count")
+    rows = [header] + [(k[0], k[1], str(v)) for k, v in entries.items()]
+    col0 = max(len(r[0]) for r in rows)
+    col1 = max(len(r[1]) for r in rows)
+    col2 = max(len(r[2]) for r in rows)
+    lines = [
+        f"| {'MM Shape'.ljust(col0)} | {'FLOPs'.ljust(col1)} | {'Count'.ljust(col2)} |",
+        f"| {'-' * col0} | {'-' * col1} | {'-' * col2} |",
+    ]
+    for s, fl, cnt in rows[1:]:
+        lines.append(f"| {s.ljust(col0)} | {fl.ljust(col1)} | {cnt.ljust(col2)} |")
+    return "\n".join(lines)
+
+
+def _wrap_with_disable_early_stop(fn):
+    def inner(*args, **kwargs):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            return fn(*args, **kwargs)
+    return inner
+
+
 def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
     valid_ac_modes = ("full", "selective")
     if ac_config.mode not in valid_ac_modes:
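Note (illustrative aside, not part of the diff): `_format_mm_flops_table` takes the `(shape_str, flops_str) -> count` entries that `_custom_policy` below accumulates in `meta` and renders them as an aligned markdown-style table. A minimal sketch of its output, assuming the helper above is in scope; the shape and count are invented for illustration:

```python
# Hypothetical entry: 32 recomputes of an (8192x4096) @ (4096x4096) mm,
# whose 2*m*k*n flop count formats to 2.75E+11 via Decimal.
entries = {("(8192x4096) x (4096x4096)", "2.75E+11"): 32}
print(_format_mm_flops_table(entries))
# | MM Shape                  | FLOPs    | Count |
# | ------------------------- | -------- | ----- |
# | (8192x4096) x (4096x4096) | 2.75E+11 | 32    |
```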
@@ -264,12 +288,38 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
         def _get_custom_policy(meta):
             def _custom_policy(ctx, func, *args, **kwargs):
                 mode = "recompute" if ctx.is_recompute else "forward"
+
+                mm_count_key_filtered = f"{mode}_mm_count_filtered"
                 mm_count_key = f"{mode}_mm_count"
+
                 if func == torch.ops.aten.mm.default:
                     meta[mm_count_key] += 1
+
+                    m, k = args[0].shape
+                    k2, n = args[1].shape
+                    assert k == k2
+                    flops = m * n * 2 * k
+
+                    if ac_config.log_mm_flops and ctx.is_recompute:
+                        shape_str = f"({m}x{k}) x ({k2}x{n})"
+                        flops_str = f"{Decimal(flops):.2E}"
+                        key = (shape_str, flops_str)
+                        meta[key] += 1
+                        if meta["recompute_mm_count"] == meta["forward_mm_count"]:
+                            table = _format_mm_flops_table({k: v for k, v in meta.items() if "mm_count" not in k})
+                            logger.info("\n%s", table)
+
+                    # Filter out ops below a certain flop threshold. See discussion for why we
+                    # recompute instead of save here:
+                    # https://github.com/pytorch/torchtitan/pull/1372#discussion_r2193722200
+                    if flops < ac_config.selective_op_ac_mm_flops_threshold:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
+
+                    meta[mm_count_key_filtered] += 1
+
                 # Saves output of all compute ops, except every second mm
                 to_save = func in _save_list and not (
-                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                    func == torch.ops.aten.mm.default and meta[mm_count_key_filtered] % 2 == 0
                 )
                 return (
                     CheckpointPolicy.MUST_SAVE
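Aside (a minimal standalone sketch, not part of the diff): the same idea of recomputing cheap matmuls and saving expensive ones, written directly against the public `torch.utils.checkpoint` API. `make_flops_aware_policy`, the toy `block`, the tensor sizes, and the `1e8` threshold are invented for illustration; `CheckpointPolicy`, `create_selective_checkpoint_contexts`, and `checkpoint(..., context_fn=...)` are the same PyTorch hooks the diff above relies on.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def make_flops_aware_policy(flops_threshold):
    def policy(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default:
            m, k = args[0].shape
            _, n = args[1].shape
            # An (m x k) @ (k x n) matmul costs roughly 2 * m * k * n flops.
            if 2 * m * k * n >= flops_threshold:
                # Large matmuls are expensive to redo: save their outputs.
                return CheckpointPolicy.MUST_SAVE
        # Small matmuls (and everything else) get recomputed in backward.
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy


# Bias-free Linear on a 2D input typically lowers to aten.mm, so the policy sees it.
block = nn.Sequential(
    nn.Linear(512, 2048, bias=False), nn.GELU(), nn.Linear(2048, 512, bias=False)
)
x = torch.randn(64, 512, requires_grad=True)
out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(
        make_flops_aware_policy(1e8)
    ),
)
out.sum().backward()
```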
@@ -283,9 +333,17 @@ def selective_checkpointing_context_fn():
             meta = defaultdict(int)
             return create_selective_checkpoint_contexts(_get_custom_policy(meta))
 
+        checkpoint_fn = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
+        if ac_config.log_mm_flops:
+            # If early-stop is enabled, fewer mm are recomputed than in forward. Disabling
+            # this will slightly alter perf, but allows us to deterministically know when to
+            # log the mm flops table rather than having to spam it for every mm call.
+            checkpoint_fn = _wrap_with_disable_early_stop(checkpoint_fn)
+
         return ptd_checkpoint_wrapper(
             module,
             context_fn=selective_checkpointing_context_fn,
+            checkpoint_fn=checkpoint_fn,
             preserve_rng_state=False,
         )
     elif use_layer_sac:
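Aside (illustration, not from the PR): `set_checkpoint_early_stop(False)` must be active around the checkpointed forward call, which is what `_wrap_with_disable_early_stop` arranges for the `checkpoint_fn` handed to `ptd_checkpoint_wrapper` above. In isolation the effect looks roughly like this; the toy function and shapes are made up:

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop


def fwd(x):
    # Two matmuls; with early stop enabled, the recompute pass may halt after
    # rebuilding only the activations backward actually needs.
    return torch.relu(x @ x) @ x


x = torch.randn(128, 128, requires_grad=True)
# Disabling early stop makes the recompute pass replay the whole forward, so the
# recompute-side mm count catches up to the forward-side count exactly once per
# backward -- the point at which the flops table above gets logged.
with set_checkpoint_early_stop(False):
    out = checkpoint(fwd, x, use_reentrant=False)
out.sum().backward()
```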
@@ -305,7 +363,9 @@ def apply_ac(model: nn.Module, ac_config):
         transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+    logger.info(f"Applied {ac_config.mode} checkpointing to the model")
+    if ac_config.selective_ac_option == "op" and ac_config.log_mm_flops:
+        logger.info("Logging enabled for mm flops.")
 
 
 def apply_compile(model: nn.Module):